openvm_circuit/system/connector/mod.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    marker::PhantomData,
4    sync::Arc,
5};
6
7use openvm_circuit_primitives::var_range::{
8    SharedVariableRangeCheckerChip, VariableRangeCheckerBus,
9};
10use openvm_circuit_primitives_derive::AlignedBorrow;
11use openvm_instructions::LocalOpcode;
12use openvm_stark_backend::{
13    config::{StarkGenericConfig, Val},
14    interaction::InteractionBuilder,
15    p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder},
16    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
17    p3_matrix::{dense::RowMajorMatrix, Matrix},
18    prover::{cpu::CpuBackend, types::AirProvingContext},
19    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
20    Chip, ChipUsageGetter,
21};
22use serde::{Deserialize, Serialize};
23
24use crate::{
25    arch::{instructions::SystemOpcode::TERMINATE, ExecutionBus, ExecutionState},
26    system::program::ProgramBus,
27};
28
29#[cfg(test)]
30mod tests;
31
/// Exit code recorded for a segment whose program has not terminated (execution was suspended
/// before reaching `TERMINATE`).
///
/// The AIR places no constraint on the exit code of a non-terminated segment, so this value is
/// only a placeholder that the prover writes into the trace and public values.
pub const DEFAULT_SUSPEND_EXIT_CODE: u32 = 42;
35
/// AIR for the VM connector: a two-row trace holding the initial and final execution state of a
/// segment, bound to the public values in [`VmConnectorPvs`].
36#[derive(Debug, Clone, Copy)]
37pub struct VmConnectorAir {
    /// Execution bus on which the segment's boundary states are exchanged (see `eval`).
38    pub execution_bus: ExecutionBus,
    /// Program bus, used to look up the `TERMINATE` instruction at the final PC when the program
    /// terminates.
39    pub program_bus: ProgramBus,
    /// Variable range-checker bus used to range check the limbs of the boundary timestamps.
40    pub range_bus: VariableRangeCheckerBus,
41    /// The final timestamp will be constrained to be in the range [0, 2^timestamp_max_bits).
    /// (The range check is applied to the timestamp of both trace rows.)
42    timestamp_max_bits: usize,
43}
44
/// Public values exposed by [`VmConnectorAir`] for a single segment.
///
/// The struct is `#[repr(C)]` and is borrowed directly from the public-values slice, so the field
/// order defines the public-value layout.
45#[derive(Debug, Clone, Copy, AlignedBorrow)]
46#[repr(C)]
47pub struct VmConnectorPvs<F> {
48    /// The initial PC of this segment.
49    pub initial_pc: F,
50    /// The final PC of this segment.
51    pub final_pc: F,
52    /// The exit code of the whole program. 0 means exited normally. This is only meaningful when
53    /// `is_terminate` is 1.
    /// For a segment that did not terminate, `VmConnectorChip` fills in
    /// [`DEFAULT_SUSPEND_EXIT_CODE`].
54    pub exit_code: F,
55    /// Whether the whole program is terminated. 0 means not terminated. 1 means terminated.
56    /// Only the last segment of an execution can have `is_terminate` = 1.
57    pub is_terminate: F,
58}
59
60impl<F: PrimeField32> VmConnectorPvs<F> {
61    pub fn is_terminate(&self) -> bool {
62        self.is_terminate == F::from_bool(true)
63    }
64
65    pub fn exit_code(&self) -> Option<u32> {
66        if self.is_terminate() && self.exit_code == F::ZERO {
67            Some(self.exit_code.as_canonical_u32())
68        } else {
69            None
70        }
71    }
72}
73
74impl<F: Field> BaseAirWithPublicValues<F> for VmConnectorAir {
75    fn num_public_values(&self) -> usize {
76        VmConnectorPvs::<F>::width()
77    }
78}
79impl<F: Field> PartitionedBaseAir<F> for VmConnectorAir {}
80impl<F: Field> BaseAir<F> for VmConnectorAir {
81    fn width(&self) -> usize {
82        5
83    }
84
85    fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
86        Some(RowMajorMatrix::new_col(vec![F::ZERO, F::ONE]))
87    }
88}
89
90impl VmConnectorAir {
91    pub fn new(
92        execution_bus: ExecutionBus,
93        program_bus: ProgramBus,
94        range_bus: VariableRangeCheckerBus,
95        timestamp_max_bits: usize,
96    ) -> Self {
97        assert!(
98            range_bus.range_max_bits * 2 >= timestamp_max_bits,
99            "Range checker not large enough: range_max_bits={}, timestamp_max_bits={}",
100            range_bus.range_max_bits,
101            timestamp_max_bits
102        );
103        Self {
104            execution_bus,
105            program_bus,
106            range_bus,
107            timestamp_max_bits,
108        }
109    }
110
111    /// Returns (low_bits, high_bits) to range check.
112    fn timestamp_limb_bits(&self) -> (usize, usize) {
113        let range_max_bits = self.range_bus.range_max_bits;
114        if self.timestamp_max_bits <= range_max_bits {
115            (self.timestamp_max_bits, 0)
116        } else {
117            (range_max_bits, self.timestamp_max_bits - range_max_bits)
118        }
119    }
120}
121
/// Main-trace columns of [`VmConnectorAir`].
///
/// Row 0 holds the segment's initial state and row 1 its final state. The struct is `#[repr(C)]`,
/// so the field order is the trace column order.
122#[derive(Debug, Copy, Clone, AlignedBorrow, Serialize, Deserialize)]
123#[repr(C)]
124pub struct ConnectorCols<T> {
    /// Program counter at this boundary.
125    pub pc: T,
    /// Timestamp at this boundary; range checked on every row via its limb decomposition.
126    pub timestamp: T,
    /// Termination flag, constrained boolean on every row; only the final row's value is tied to
    /// the `is_terminate` public value.
127    pub is_terminate: T,
    /// Exit code; tied to the public value (and to the `TERMINATE` lookup) only on the final row
    /// when `is_terminate` is 1.
128    pub exit_code: T,
129    /// Lowest `range_bus.range_max_bits` bits of the timestamp
130    timestamp_low_limb: T,
131}
132
133impl<T: Copy> ConnectorCols<T> {
134    fn map<F>(self, f: impl Fn(T) -> F) -> ConnectorCols<F> {
135        ConnectorCols {
136            pc: f(self.pc),
137            timestamp: f(self.timestamp),
138            is_terminate: f(self.is_terminate),
139            exit_code: f(self.exit_code),
140            timestamp_low_limb: f(self.timestamp_low_limb),
141        }
142    }
143
144    fn flatten(&self) -> [T; 5] {
145        [
146            self.pc,
147            self.timestamp,
148            self.is_terminate,
149            self.exit_code,
150            self.timestamp_low_limb,
151        ]
152    }
153}
154
/// Constraints for the two-row connector trace.
///
/// Row 0 (`begin`) holds the segment's initial state and row 1 (`end`) its final state.
/// `when_transition` constraints apply only to the window starting at row 0 (the one that does
/// not wrap around); constraints without it apply to every row.
155impl<AB: InteractionBuilder + PairBuilder + AirBuilderWithPublicValues> Air<AB> for VmConnectorAir {
156    fn eval(&self, builder: &mut AB) {
157        let main = builder.main();
158        let preprocessed = builder.preprocessed();
        // The preprocessed column is [0, 1] (see `BaseAir::preprocessed_trace`).
159        let prep_local = preprocessed
160            .row_slice(0)
161            .expect("window should have two elements");
162        let (begin, end) = (
163            main.row_slice(0).expect("window should have two elements"),
164            main.row_slice(1).expect("window should have two elements"),
165        );
166
167        let begin: &ConnectorCols<AB::Var> = (*begin).borrow();
168        let end: &ConnectorCols<AB::Var> = (*end).borrow();
169
170        let &VmConnectorPvs {
171            initial_pc,
172            final_pc,
173            exit_code,
174            is_terminate,
175        } = builder.public_values().borrow();
176
        // Bind the boundary PCs to the public values.
177        builder.when_transition().assert_eq(begin.pc, initial_pc);
178        builder.when_transition().assert_eq(end.pc, final_pc);
        // The exit-code public value is only constrained when the program terminated.
179        builder
180            .when_transition()
181            .when(end.is_terminate)
182            .assert_eq(end.exit_code, exit_code);
        // The final row's termination flag is exposed as the `is_terminate` public value.
183        builder
184            .when_transition()
185            .assert_eq(end.is_terminate, is_terminate);
186        // Assert is_terminate is boolean on every row to ensure lookup multiplicity is boolean
187        // below
188        builder.assert_bool(begin.is_terminate);
189
        // The initial timestamp of the segment is constrained to be 1.
190        builder.when_transition().assert_one(begin.timestamp);
191
        // Execution-bus interaction for the boundary states: the final state is passed first and
        // the initial state second. NOTE(review): see `ExecutionBus::execute` for which side is
        // sent vs. received.
192        self.execution_bus.execute(
193            builder,
194            AB::Expr::ONE - prep_local[0], // 1 only if these are [0th, 1st] and not [1st, 0th]
195            ExecutionState::new(end.pc, end.timestamp),
196            ExecutionState::new(begin.pc, begin.timestamp),
197        );
        // If the program terminated, the instruction at the final PC must be `TERMINATE` with
        // operands `[0, 0, exit_code]`.
198        self.program_bus.lookup_instruction(
199            builder,
200            end.pc,
201            AB::Expr::from_usize(TERMINATE.global_opcode().as_usize()),
202            [AB::Expr::ZERO, AB::Expr::ZERO, end.exit_code.into()],
203            (AB::Expr::ONE - prep_local[0]) * end.is_terminate,
204        );
205
206        // The following constraints hold on every row, so we rename `begin` to `local` to avoid
207        // confusion.
208        let local = begin;
209        // We decompose and range check `local.timestamp` as `timestamp_low_limb,
210        // timestamp_high_limb` where `timestamp = timestamp_low_limb + timestamp_high_limb
211        // * 2^range_max_bits`.
212        let (low_bits, high_bits) = self.timestamp_limb_bits();
        // The high limb is not a column: it is recovered as (timestamp - low_limb) / 2^range_max_bits,
        // and range checking it (with `high_bits`, presumably forcing 0 when `high_bits` is 0)
        // bounds the whole timestamp to `timestamp_max_bits` bits.
213        let high_limb = (local.timestamp - local.timestamp_low_limb)
214            * AB::F::ONE.div_2exp_u64(self.range_bus.range_max_bits as u64);
215        self.range_bus
216            .range_check(local.timestamp_low_limb, low_bits)
217            .eval(builder, AB::Expr::ONE);
218        self.range_bus
219            .range_check(high_limb, high_bits)
220            .eval(builder, AB::Expr::ONE);
221    }
222}
223
/// Trace generator for [`VmConnectorAir`]: records a segment's boundary states and produces the
/// two-row trace and the public values.
224pub struct VmConnectorChip<F> {
    /// Shared range checker that receives the counts for the timestamp-limb range checks.
225    pub range_checker: SharedVariableRangeCheckerChip,
    /// `[initial, final]` boundary rows, set by `begin` and `end` respectively. Both must be set
    /// before trace generation.
226    pub boundary_states: [Option<ConnectorCols<u32>>; 2],
    /// Maximum bit-length of a boundary timestamp. NOTE(review): presumably must equal the value
    /// passed to `VmConnectorAir::new` so the range-check counts match; confirm at the call site.
227    timestamp_max_bits: usize,
228    _marker: PhantomData<F>,
229}
230
231impl<F> VmConnectorChip<F> {
232    pub fn new(range_checker: SharedVariableRangeCheckerChip, timestamp_max_bits: usize) -> Self {
233        let range_bus = range_checker.bus();
234        assert!(
235            range_bus.range_max_bits * 2 >= timestamp_max_bits,
236            "Range checker not large enough: range_max_bits={}, timestamp_max_bits={}",
237            range_bus.range_max_bits,
238            timestamp_max_bits
239        );
240        Self {
241            range_checker,
242            boundary_states: [None, None],
243            timestamp_max_bits,
244            _marker: PhantomData,
245        }
246    }
247
248    pub fn begin(&mut self, state: ExecutionState<u32>) {
249        self.boundary_states[0] = Some(ConnectorCols {
250            pc: state.pc,
251            timestamp: state.timestamp,
252            is_terminate: 0,
253            exit_code: 0,
254            timestamp_low_limb: 0, // will be computed during tracegen
255        });
256    }
257
258    pub fn end(&mut self, state: ExecutionState<u32>, exit_code: Option<u32>) {
259        self.boundary_states[1] = Some(ConnectorCols {
260            pc: state.pc,
261            timestamp: state.timestamp,
262            is_terminate: exit_code.is_some() as u32,
263            exit_code: exit_code.unwrap_or(DEFAULT_SUSPEND_EXIT_CODE),
264            timestamp_low_limb: 0, // will be computed during tracegen
265        });
266    }
267
268    fn timestamp_limb_bits(&self) -> (usize, usize) {
269        let range_max_bits = self.range_checker.bus().range_max_bits;
270        if self.timestamp_max_bits <= range_max_bits {
271            (self.timestamp_max_bits, 0)
272        } else {
273            (range_max_bits, self.timestamp_max_bits - range_max_bits)
274        }
275    }
276}
277
278impl<RA, SC> Chip<RA, CpuBackend<SC>> for VmConnectorChip<Val<SC>>
279where
280    SC: StarkGenericConfig,
281    Val<SC>: PrimeField32,
282{
283    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<CpuBackend<SC>> {
284        let [initial_state, final_state] = self.boundary_states.map(|state| {
285            let mut state = state.unwrap();
286            // Decompose and range check timestamp
287            let range_max_bits = self.range_checker.range_max_bits();
288            let timestamp_low_limb = state.timestamp & ((1u32 << range_max_bits) - 1);
289            state.timestamp_low_limb = timestamp_low_limb;
290            let (low_bits, high_bits) = self.timestamp_limb_bits();
291            self.range_checker.add_count(timestamp_low_limb, low_bits);
292            self.range_checker
293                .add_count(state.timestamp >> range_max_bits, high_bits);
294
295            state.map(Val::<SC>::from_u32)
296        });
297
298        let trace = Arc::new(RowMajorMatrix::new(
299            [initial_state.flatten(), final_state.flatten()].concat(),
300            self.trace_width(),
301        ));
302
303        let mut public_values = Val::<SC>::zero_vec(VmConnectorPvs::<Val<SC>>::width());
304        *public_values.as_mut_slice().borrow_mut() = VmConnectorPvs {
305            initial_pc: initial_state.pc,
306            final_pc: final_state.pc,
307            exit_code: final_state.exit_code,
308            is_terminate: final_state.is_terminate,
309        };
310        AirProvingContext::simple(trace, public_values)
311    }
312}
313
314impl<F: PrimeField32> ChipUsageGetter for VmConnectorChip<F> {
315    fn air_name(&self) -> String {
316        "VmConnectorAir".to_string()
317    }
318
319    fn constant_trace_height(&self) -> Option<usize> {
320        Some(2)
321    }
322
323    fn current_trace_height(&self) -> usize {
324        2
325    }
326
327    fn trace_width(&self) -> usize {
328        5
329    }
330}