// openvm_native_circuit/jal/mod.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    ops::Deref,
4    sync::{Arc, Mutex},
5};
6
7use openvm_circuit::{
8    arch::{ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, PcIncOrSet},
9    system::memory::{
10        offline_checker::{MemoryBridge, MemoryWriteAuxCols},
11        MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId,
12    },
13};
14use openvm_circuit_primitives::{
15    utils::next_power_of_two_or_zero,
16    var_range::{
17        SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip,
18    },
19};
20use openvm_circuit_primitives_derive::AlignedBorrow;
21use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
22use openvm_native_compiler::{conversion::AS, NativeJalOpcode, NativeRangeCheckOpcode};
23use openvm_stark_backend::{
24    config::{StarkGenericConfig, Val},
25    interaction::InteractionBuilder,
26    p3_air::{Air, AirBuilder, BaseAir},
27    p3_field::{Field, FieldAlgebra, PrimeField32},
28    p3_matrix::{dense::RowMajorMatrix, Matrix},
29    p3_maybe_rayon::prelude::*,
30    prover::types::AirProofInput,
31    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
32    AirRef, Chip, ChipUsageGetter,
33};
34use serde::{Deserialize, Serialize};
35use static_assertions::const_assert_eq;
36use AS::Native;
37
// Tests for this chip live in the sibling `tests` module file; compiled only for test builds.
#[cfg(test)]
mod tests;
40
/// Trace columns shared by the JAL and RANGE_CHECK opcodes; each row holds one instruction
/// (or is an all-zero padding row).
#[repr(C)]
#[derive(AlignedBorrow)]
struct JalRangeCheckCols<T> {
    // 1 on JAL rows, 0 otherwise.
    is_jal: T,
    // 1 on RANGE_CHECK rows, 0 otherwise.
    is_range_check: T,
    // Operand `a`: pointer of the native-address-space cell that is accessed.
    a_pointer: T,
    // pc and timestamp at which the instruction executes.
    state: ExecutionState<T>,
    // Write when is_jal, read when is_range_check.
    writes_aux: MemoryWriteAuxCols<T, 1>,
    // Operand `b`: pc increment for JAL; bit bound on the low 16 bits (`x`) for RANGE_CHECK.
    b: T,
    // Only used by range check.
    // Operand `c`: bit bound on the high part `y`; constrained to 0 on JAL rows.
    c: T,
    // Only used by range check.
    // High part of the checked value: a_val = x + y * 2^16.
    y: T,
}
56
57const OVERALL_WIDTH: usize = JalRangeCheckCols::<u8>::width();
58const_assert_eq!(OVERALL_WIDTH, 12);
59
/// AIR constraining the combined JAL / RANGE_CHECK trace.
#[derive(Copy, Clone, Debug)]
pub struct JalRangeCheckAir {
    // Links each row to its instruction and to the pc/timestamp transition.
    execution_bridge: ExecutionBridge,
    // Memory access to operand `a` in the native address space.
    memory_bridge: MemoryBridge,
    // Variable range checker bus used to bound `x` and `y` on RANGE_CHECK rows.
    range_bus: VariableRangeCheckerBus,
}
66
impl<F: Field> BaseAir<F> for JalRangeCheckAir {
    /// Trace width: the number of columns in `JalRangeCheckCols`.
    fn width(&self) -> usize {
        OVERALL_WIDTH
    }
}
72
// Both traits use their default methods; this AIR has no public values (trace generation
// uses `AirProofInput::simple_no_pis`).
impl<F: Field> BaseAirWithPublicValues<F> for JalRangeCheckAir {}
impl<F: Field> PartitionedBaseAir<F> for JalRangeCheckAir {}
impl<AB: InteractionBuilder> Air<AB> for JalRangeCheckAir
where
    AB::F: PrimeField32,
{
    /// Constrains one row of the combined trace.
    ///
    /// A row is a JAL row (`is_jal = 1`), a RANGE_CHECK row (`is_range_check = 1`), or a
    /// padding row (both flags 0), in which case every interaction below is disabled.
    /// NOTE(review): the order of the interactions presumably fixes the AIR's interaction
    /// list; keep it stable.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local_slice = local.deref();
        let local: &JalRangeCheckCols<AB::Var> = local_slice.borrow();
        // Both flags are boolean and at most one is set.
        builder.assert_bool(local.is_jal);
        builder.assert_bool(local.is_range_check);
        let is_valid = local.is_jal + local.is_range_check;
        builder.assert_bool(is_valid.clone());

        // Operand `a` always lives in the native address space.
        let d = AB::Expr::from_canonical_u32(Native as u32);
        // Value held by cell `a` before this access.
        let a_val = local.writes_aux.prev_data()[0];
        // if is_jal, write pc + DEFAULT_PC_STEP, else if is_range_check, read a_val.
        // (The read is modelled as a write of the unchanged value `a_val`.)
        let write_val = local.is_jal
            * (local.state.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_STEP))
            + local.is_range_check * a_val;
        self.memory_bridge
            .write(
                MemoryAddress::new(d.clone(), local.a_pointer),
                [write_val],
                local.state.timestamp,
                &local.writes_aux,
            )
            .eval(builder, is_valid.clone());

        // Global opcode selected by the row's flag.
        let opcode = local.is_jal
            * AB::F::from_canonical_usize(NativeJalOpcode::JAL.global_opcode().as_usize())
            + local.is_range_check
                * AB::F::from_canonical_usize(
                    NativeRangeCheckOpcode::RANGE_CHECK
                        .global_opcode()
                        .as_usize(),
                );
        // Increment pc by b if is_jal, else by DEFAULT_PC_STEP if is_range_check.
        let pc_inc = local.is_jal * local.b
            + local.is_range_check * AB::F::from_canonical_u32(DEFAULT_PC_STEP);
        // JAL rows must carry c = 0 (the executor records c = 0 for JAL).
        builder.when(local.is_jal).assert_zero(local.c);
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                opcode,
                [local.a_pointer.into(), local.b.into(), local.c.into(), d],
                local.state,
                AB::F::ONE,
                PcIncOrSet::Inc(pc_inc),
            )
            .eval(builder, is_valid);

        // Range check specific:
        // a_val = x + y * (1 << 16)
        let x = a_val - local.y * AB::Expr::from_canonical_u32(1 << 16);
        // Assert x < (1 << b).
        self.range_bus
            .send(x.clone(), local.b)
            .eval(builder, local.is_range_check);
        // Assert y < (1 << c). This AIR does not bound c itself; c <= 14 is assumed (the
        // executor asserts it in debug mode). NOTE(review): presumably that bound is what keeps
        // x + y * 2^16 below the field modulus so the decomposition is unique — confirm.
        self.range_bus
            .send(local.y, local.c)
            .eval(builder, local.is_range_check);
    }
}
138
139impl JalRangeCheckAir {
140    fn new(
141        execution_bridge: ExecutionBridge,
142        memory_bridge: MemoryBridge,
143        range_bus: VariableRangeCheckerBus,
144    ) -> Self {
145        Self {
146            execution_bridge,
147            memory_bridge,
148            range_bus,
149        }
150    }
151}
152
/// Per-instruction execution record consumed by trace generation.
#[repr(C)]
#[derive(Serialize, Deserialize)]
pub struct JalRangeCheckRecord {
    /// pc and timestamp before the instruction executed.
    pub state: ExecutionState<u32>,
    /// Memory record id of the access to operand `a` (for RANGE_CHECK, a write of the
    /// unchanged value).
    pub a_rw: RecordId,
    /// Operand `b` as a canonical `u32`.
    pub b: u32,
    /// Operand `c`; always 0 for JAL.
    pub c: u8,
    /// `true` for JAL, `false` for RANGE_CHECK.
    pub is_jal: bool,
}
162
/// Chip for JAL and RANGE_CHECK. These opcodes are logically unrelated. Putting them into
/// the same chip is just to save columns.
pub struct JalRangeCheckChip<F> {
    air: JalRangeCheckAir,
    /// One record per executed instruction, in execution order.
    pub records: Vec<JalRangeCheckRecord>,
    // Used during trace generation to resolve the records' memory record ids.
    offline_memory: Arc<Mutex<OfflineMemory<F>>>,
    // Receives the x / y range-check counts during trace generation.
    range_checker_chip: SharedVariableRangeCheckerChip,
    /// If true, ignore execution errors: failed range checks do not make `execute` return an
    /// error (debug builds instead assert `b <= 16` and `c <= 14`).
    debug: bool,
}
173
174impl<F: PrimeField32> JalRangeCheckChip<F> {
175    pub fn new(
176        execution_bridge: ExecutionBridge,
177        offline_memory: Arc<Mutex<OfflineMemory<F>>>,
178        range_checker_chip: SharedVariableRangeCheckerChip,
179    ) -> Self {
180        let memory_bridge = offline_memory.lock().unwrap().memory_bridge();
181        let air = JalRangeCheckAir::new(execution_bridge, memory_bridge, range_checker_chip.bus());
182        Self {
183            air,
184            records: vec![],
185            offline_memory,
186            range_checker_chip,
187            debug: false,
188        }
189    }
190    pub fn with_debug(mut self) -> Self {
191        self.debug = true;
192        self
193    }
194}
195
196impl<F: PrimeField32> InstructionExecutor<F> for JalRangeCheckChip<F> {
197    fn execute(
198        &mut self,
199        memory: &mut MemoryController<F>,
200        instruction: &Instruction<F>,
201        from_state: ExecutionState<u32>,
202    ) -> Result<ExecutionState<u32>, ExecutionError> {
203        if instruction.opcode == NativeJalOpcode::JAL.global_opcode() {
204            let (record_id, _) = memory.write(
205                F::from_canonical_u32(AS::Native as u32),
206                instruction.a,
207                [F::from_canonical_u32(from_state.pc + DEFAULT_PC_STEP)],
208            );
209            let b = instruction.b.as_canonical_u32();
210            self.records.push(JalRangeCheckRecord {
211                state: from_state,
212                a_rw: record_id,
213                b,
214                c: 0,
215                is_jal: true,
216            });
217            return Ok(ExecutionState {
218                pc: (F::from_canonical_u32(from_state.pc) + instruction.b).as_canonical_u32(),
219                timestamp: memory.timestamp(),
220            });
221        } else if instruction.opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() {
222            let d = F::from_canonical_u32(AS::Native as u32);
223            // This is a read, but we make the record have prev_data
224            let a_val = memory.unsafe_read_cell(d, instruction.a);
225            let (record_id, _) = memory.write(d, instruction.a, [a_val]);
226            let a_val = a_val.as_canonical_u32();
227            let b = instruction.b.as_canonical_u32();
228            let c = instruction.c.as_canonical_u32();
229            debug_assert!(!self.debug || b <= 16);
230            debug_assert!(!self.debug || c <= 14);
231            let x = a_val & ((1 << 16) - 1);
232            if !self.debug && x >= 1 << b {
233                return Err(ExecutionError::Fail { pc: from_state.pc });
234            }
235            let y = a_val >> 16;
236            if !self.debug && y >= 1 << c {
237                return Err(ExecutionError::Fail { pc: from_state.pc });
238            }
239            self.records.push(JalRangeCheckRecord {
240                state: from_state,
241                a_rw: record_id,
242                b,
243                c: c as u8,
244                is_jal: false,
245            });
246            return Ok(ExecutionState {
247                pc: from_state.pc + DEFAULT_PC_STEP,
248                timestamp: memory.timestamp(),
249            });
250        }
251        panic!("Unknown opcode {}", instruction.opcode);
252    }
253
254    fn get_opcode_name(&self, opcode: usize) -> String {
255        let jal_opcode = NativeJalOpcode::JAL.global_opcode().as_usize();
256        let range_check_opcode = NativeRangeCheckOpcode::RANGE_CHECK
257            .global_opcode()
258            .as_usize();
259        if opcode == jal_opcode {
260            return String::from("JAL");
261        }
262        if opcode == range_check_opcode {
263            return String::from("RANGE_CHECK");
264        }
265        panic!("Unknown opcode {}", opcode);
266    }
267}
268
269impl<F: Field> ChipUsageGetter for JalRangeCheckChip<F> {
270    fn air_name(&self) -> String {
271        "JalRangeCheck".to_string()
272    }
273
274    fn current_trace_height(&self) -> usize {
275        self.records.len()
276    }
277
278    fn trace_width(&self) -> usize {
279        OVERALL_WIDTH
280    }
281}
282
283impl<SC: StarkGenericConfig> Chip<SC> for JalRangeCheckChip<Val<SC>>
284where
285    Val<SC>: PrimeField32,
286{
287    fn air(&self) -> AirRef<SC> {
288        Arc::new(self.air)
289    }
290    fn generate_air_proof_input(self) -> AirProofInput<SC> {
291        let height = next_power_of_two_or_zero(self.records.len());
292        let mut flat_trace = Val::<SC>::zero_vec(OVERALL_WIDTH * height);
293        let memory = self.offline_memory.lock().unwrap();
294        let aux_cols_factory = memory.aux_cols_factory();
295
296        self.records
297            .into_par_iter()
298            .zip(flat_trace.par_chunks_mut(OVERALL_WIDTH))
299            .for_each(|(record, slice)| {
300                record_to_row(
301                    record,
302                    &aux_cols_factory,
303                    self.range_checker_chip.as_ref(),
304                    slice,
305                    &memory,
306                );
307            });
308
309        let matrix = RowMajorMatrix::new(flat_trace, OVERALL_WIDTH);
310        AirProofInput::simple_no_pis(matrix)
311    }
312}
313
314fn record_to_row<F: PrimeField32>(
315    record: JalRangeCheckRecord,
316    aux_cols_factory: &MemoryAuxColsFactory<F>,
317    range_checker_chip: &VariableRangeCheckerChip,
318    slice: &mut [F],
319    memory: &OfflineMemory<F>,
320) {
321    let a_record = memory.record_by_id(record.a_rw);
322    let col: &mut JalRangeCheckCols<_> = slice.borrow_mut();
323    col.is_jal = F::from_bool(record.is_jal);
324    col.is_range_check = F::from_bool(!record.is_jal);
325    col.a_pointer = a_record.pointer;
326    col.state = ExecutionState {
327        pc: F::from_canonical_u32(record.state.pc),
328        timestamp: F::from_canonical_u32(record.state.timestamp),
329    };
330    aux_cols_factory.generate_write_aux(a_record, &mut col.writes_aux);
331    col.b = F::from_canonical_u32(record.b);
332    if !record.is_jal {
333        let a_val = a_record.data_at(0);
334        let a_val_u32 = a_val.as_canonical_u32();
335        let y = a_val_u32 >> 16;
336        let x = a_val_u32 & ((1 << 16) - 1);
337        range_checker_chip.add_count(x, record.b as usize);
338        range_checker_chip.add_count(y, record.c as usize);
339        col.c = F::from_canonical_u32(record.c as u32);
340        col.y = F::from_canonical_u32(y);
341    }
342}