1use std::{
2 borrow::{Borrow, BorrowMut},
3 ops::Deref,
4 sync::{Arc, Mutex},
5};
6
7use openvm_circuit::{
8 arch::{ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, PcIncOrSet},
9 system::memory::{
10 offline_checker::{MemoryBridge, MemoryWriteAuxCols},
11 MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId,
12 },
13};
14use openvm_circuit_primitives::{
15 utils::next_power_of_two_or_zero,
16 var_range::{
17 SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip,
18 },
19};
20use openvm_circuit_primitives_derive::AlignedBorrow;
21use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
22use openvm_native_compiler::{conversion::AS, NativeJalOpcode, NativeRangeCheckOpcode};
23use openvm_stark_backend::{
24 config::{StarkGenericConfig, Val},
25 interaction::InteractionBuilder,
26 p3_air::{Air, AirBuilder, BaseAir},
27 p3_field::{Field, FieldAlgebra, PrimeField32},
28 p3_matrix::{dense::RowMajorMatrix, Matrix},
29 p3_maybe_rayon::prelude::*,
30 prover::types::AirProofInput,
31 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
32 AirRef, Chip, ChipUsageGetter,
33};
34use serde::{Deserialize, Serialize};
35use static_assertions::const_assert_eq;
36use AS::Native;
37
38#[cfg(test)]
39mod tests;
40
41#[repr(C)]
42#[derive(AlignedBorrow)]
43struct JalRangeCheckCols<T> {
44 is_jal: T,
45 is_range_check: T,
46 a_pointer: T,
47 state: ExecutionState<T>,
48 writes_aux: MemoryWriteAuxCols<T, 1>,
50 b: T,
51 c: T,
53 y: T,
55}
56
57const OVERALL_WIDTH: usize = JalRangeCheckCols::<u8>::width();
58const_assert_eq!(OVERALL_WIDTH, 12);
59
60#[derive(Copy, Clone, Debug)]
61pub struct JalRangeCheckAir {
62 execution_bridge: ExecutionBridge,
63 memory_bridge: MemoryBridge,
64 range_bus: VariableRangeCheckerBus,
65}
66
67impl<F: Field> BaseAir<F> for JalRangeCheckAir {
68 fn width(&self) -> usize {
69 OVERALL_WIDTH
70 }
71}
72
73impl<F: Field> BaseAirWithPublicValues<F> for JalRangeCheckAir {}
74impl<F: Field> PartitionedBaseAir<F> for JalRangeCheckAir {}
75impl<AB: InteractionBuilder> Air<AB> for JalRangeCheckAir
76where
77 AB::F: PrimeField32,
78{
79 fn eval(&self, builder: &mut AB) {
80 let main = builder.main();
81 let local = main.row_slice(0);
82 let local_slice = local.deref();
83 let local: &JalRangeCheckCols<AB::Var> = local_slice.borrow();
84 builder.assert_bool(local.is_jal);
85 builder.assert_bool(local.is_range_check);
86 let is_valid = local.is_jal + local.is_range_check;
87 builder.assert_bool(is_valid.clone());
88
89 let d = AB::Expr::from_canonical_u32(Native as u32);
90 let a_val = local.writes_aux.prev_data()[0];
91 let write_val = local.is_jal
93 * (local.state.pc + AB::Expr::from_canonical_u32(DEFAULT_PC_STEP))
94 + local.is_range_check * a_val;
95 self.memory_bridge
96 .write(
97 MemoryAddress::new(d.clone(), local.a_pointer),
98 [write_val],
99 local.state.timestamp,
100 &local.writes_aux,
101 )
102 .eval(builder, is_valid.clone());
103
104 let opcode = local.is_jal
105 * AB::F::from_canonical_usize(NativeJalOpcode::JAL.global_opcode().as_usize())
106 + local.is_range_check
107 * AB::F::from_canonical_usize(
108 NativeRangeCheckOpcode::RANGE_CHECK
109 .global_opcode()
110 .as_usize(),
111 );
112 let pc_inc = local.is_jal * local.b
114 + local.is_range_check * AB::F::from_canonical_u32(DEFAULT_PC_STEP);
115 builder.when(local.is_jal).assert_zero(local.c);
116 self.execution_bridge
117 .execute_and_increment_or_set_pc(
118 opcode,
119 [local.a_pointer.into(), local.b.into(), local.c.into(), d],
120 local.state,
121 AB::F::ONE,
122 PcIncOrSet::Inc(pc_inc),
123 )
124 .eval(builder, is_valid);
125
126 let x = a_val - local.y * AB::Expr::from_canonical_u32(1 << 16);
129 self.range_bus
130 .send(x.clone(), local.b)
131 .eval(builder, local.is_range_check);
132 self.range_bus
134 .send(local.y, local.c)
135 .eval(builder, local.is_range_check);
136 }
137}
138
139impl JalRangeCheckAir {
140 fn new(
141 execution_bridge: ExecutionBridge,
142 memory_bridge: MemoryBridge,
143 range_bus: VariableRangeCheckerBus,
144 ) -> Self {
145 Self {
146 execution_bridge,
147 memory_bridge,
148 range_bus,
149 }
150 }
151}
152
153#[repr(C)]
154#[derive(Serialize, Deserialize)]
155pub struct JalRangeCheckRecord {
156 pub state: ExecutionState<u32>,
157 pub a_rw: RecordId,
158 pub b: u32,
159 pub c: u8,
160 pub is_jal: bool,
161}
162
163pub struct JalRangeCheckChip<F> {
166 air: JalRangeCheckAir,
167 pub records: Vec<JalRangeCheckRecord>,
168 offline_memory: Arc<Mutex<OfflineMemory<F>>>,
169 range_checker_chip: SharedVariableRangeCheckerChip,
170 debug: bool,
172}
173
174impl<F: PrimeField32> JalRangeCheckChip<F> {
175 pub fn new(
176 execution_bridge: ExecutionBridge,
177 offline_memory: Arc<Mutex<OfflineMemory<F>>>,
178 range_checker_chip: SharedVariableRangeCheckerChip,
179 ) -> Self {
180 let memory_bridge = offline_memory.lock().unwrap().memory_bridge();
181 let air = JalRangeCheckAir::new(execution_bridge, memory_bridge, range_checker_chip.bus());
182 Self {
183 air,
184 records: vec![],
185 offline_memory,
186 range_checker_chip,
187 debug: false,
188 }
189 }
190 pub fn with_debug(mut self) -> Self {
191 self.debug = true;
192 self
193 }
194}
195
196impl<F: PrimeField32> InstructionExecutor<F> for JalRangeCheckChip<F> {
197 fn execute(
198 &mut self,
199 memory: &mut MemoryController<F>,
200 instruction: &Instruction<F>,
201 from_state: ExecutionState<u32>,
202 ) -> Result<ExecutionState<u32>, ExecutionError> {
203 if instruction.opcode == NativeJalOpcode::JAL.global_opcode() {
204 let (record_id, _) = memory.write(
205 F::from_canonical_u32(AS::Native as u32),
206 instruction.a,
207 [F::from_canonical_u32(from_state.pc + DEFAULT_PC_STEP)],
208 );
209 let b = instruction.b.as_canonical_u32();
210 self.records.push(JalRangeCheckRecord {
211 state: from_state,
212 a_rw: record_id,
213 b,
214 c: 0,
215 is_jal: true,
216 });
217 return Ok(ExecutionState {
218 pc: (F::from_canonical_u32(from_state.pc) + instruction.b).as_canonical_u32(),
219 timestamp: memory.timestamp(),
220 });
221 } else if instruction.opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() {
222 let d = F::from_canonical_u32(AS::Native as u32);
223 let a_val = memory.unsafe_read_cell(d, instruction.a);
225 let (record_id, _) = memory.write(d, instruction.a, [a_val]);
226 let a_val = a_val.as_canonical_u32();
227 let b = instruction.b.as_canonical_u32();
228 let c = instruction.c.as_canonical_u32();
229 debug_assert!(!self.debug || b <= 16);
230 debug_assert!(!self.debug || c <= 14);
231 let x = a_val & ((1 << 16) - 1);
232 if !self.debug && x >= 1 << b {
233 return Err(ExecutionError::Fail { pc: from_state.pc });
234 }
235 let y = a_val >> 16;
236 if !self.debug && y >= 1 << c {
237 return Err(ExecutionError::Fail { pc: from_state.pc });
238 }
239 self.records.push(JalRangeCheckRecord {
240 state: from_state,
241 a_rw: record_id,
242 b,
243 c: c as u8,
244 is_jal: false,
245 });
246 return Ok(ExecutionState {
247 pc: from_state.pc + DEFAULT_PC_STEP,
248 timestamp: memory.timestamp(),
249 });
250 }
251 panic!("Unknown opcode {}", instruction.opcode);
252 }
253
254 fn get_opcode_name(&self, opcode: usize) -> String {
255 let jal_opcode = NativeJalOpcode::JAL.global_opcode().as_usize();
256 let range_check_opcode = NativeRangeCheckOpcode::RANGE_CHECK
257 .global_opcode()
258 .as_usize();
259 if opcode == jal_opcode {
260 return String::from("JAL");
261 }
262 if opcode == range_check_opcode {
263 return String::from("RANGE_CHECK");
264 }
265 panic!("Unknown opcode {}", opcode);
266 }
267}
268
269impl<F: Field> ChipUsageGetter for JalRangeCheckChip<F> {
270 fn air_name(&self) -> String {
271 "JalRangeCheck".to_string()
272 }
273
274 fn current_trace_height(&self) -> usize {
275 self.records.len()
276 }
277
278 fn trace_width(&self) -> usize {
279 OVERALL_WIDTH
280 }
281}
282
283impl<SC: StarkGenericConfig> Chip<SC> for JalRangeCheckChip<Val<SC>>
284where
285 Val<SC>: PrimeField32,
286{
287 fn air(&self) -> AirRef<SC> {
288 Arc::new(self.air)
289 }
290 fn generate_air_proof_input(self) -> AirProofInput<SC> {
291 let height = next_power_of_two_or_zero(self.records.len());
292 let mut flat_trace = Val::<SC>::zero_vec(OVERALL_WIDTH * height);
293 let memory = self.offline_memory.lock().unwrap();
294 let aux_cols_factory = memory.aux_cols_factory();
295
296 self.records
297 .into_par_iter()
298 .zip(flat_trace.par_chunks_mut(OVERALL_WIDTH))
299 .for_each(|(record, slice)| {
300 record_to_row(
301 record,
302 &aux_cols_factory,
303 self.range_checker_chip.as_ref(),
304 slice,
305 &memory,
306 );
307 });
308
309 let matrix = RowMajorMatrix::new(flat_trace, OVERALL_WIDTH);
310 AirProofInput::simple_no_pis(matrix)
311 }
312}
313
314fn record_to_row<F: PrimeField32>(
315 record: JalRangeCheckRecord,
316 aux_cols_factory: &MemoryAuxColsFactory<F>,
317 range_checker_chip: &VariableRangeCheckerChip,
318 slice: &mut [F],
319 memory: &OfflineMemory<F>,
320) {
321 let a_record = memory.record_by_id(record.a_rw);
322 let col: &mut JalRangeCheckCols<_> = slice.borrow_mut();
323 col.is_jal = F::from_bool(record.is_jal);
324 col.is_range_check = F::from_bool(!record.is_jal);
325 col.a_pointer = a_record.pointer;
326 col.state = ExecutionState {
327 pc: F::from_canonical_u32(record.state.pc),
328 timestamp: F::from_canonical_u32(record.state.timestamp),
329 };
330 aux_cols_factory.generate_write_aux(a_record, &mut col.writes_aux);
331 col.b = F::from_canonical_u32(record.b);
332 if !record.is_jal {
333 let a_val = a_record.data_at(0);
334 let a_val_u32 = a_val.as_canonical_u32();
335 let y = a_val_u32 >> 16;
336 let x = a_val_u32 & ((1 << 16) - 1);
337 range_checker_chip.add_count(x, record.b as usize);
338 range_checker_chip.add_count(y, record.c as usize);
339 col.c = F::from_canonical_u32(record.c as u32);
340 col.y = F::from_canonical_u32(y);
341 }
342}