openvm_rv32im_circuit/hintstore/
mod.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    sync::{Arc, Mutex, OnceLock},
4};
5
6use openvm_circuit::{
7    arch::{
8        ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor, Streams,
9    },
10    system::{
11        memory::{
12            offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
13            MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId,
14        },
15        program::ProgramBus,
16    },
17};
18use openvm_circuit_primitives::{
19    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
20    utils::{next_power_of_two_or_zero, not},
21};
22use openvm_circuit_primitives_derive::AlignedBorrow;
23use openvm_instructions::{
24    instruction::Instruction,
25    program::DEFAULT_PC_STEP,
26    riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
27    LocalOpcode,
28};
29use openvm_rv32im_transpiler::{
30    Rv32HintStoreOpcode,
31    Rv32HintStoreOpcode::{HINT_BUFFER, HINT_STOREW},
32};
33use openvm_stark_backend::{
34    config::{StarkGenericConfig, Val},
35    interaction::InteractionBuilder,
36    p3_air::{Air, AirBuilder, BaseAir},
37    p3_field::{Field, FieldAlgebra, PrimeField32},
38    p3_matrix::{dense::RowMajorMatrix, Matrix},
39    prover::types::AirProofInput,
40    rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir},
41    Chip, ChipUsageGetter,
42};
43use serde::{Deserialize, Serialize};
44
45use crate::adapters::{compose, decompose};
46
47#[cfg(test)]
48mod tests;
49
/// Main-trace columns of [`Rv32HintStoreAir`]: one row per hint word written to memory.
///
/// `#[repr(C)]` together with `AlignedBorrow` lets a row slice be borrowed as this struct,
/// so the field order below *is* the column layout and must not be changed.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32HintStoreCols<T> {
    // common
    /// 1 on the (only) row of a `HINT_STOREW` instruction, 0 otherwise.
    pub is_single: T,
    /// 1 on every row of a `HINT_BUFFER` instruction, 0 otherwise.
    pub is_buffer: T,
    /// Little-endian limbs of the number of words still to be written, counting the
    /// current row. Always 1 for `HINT_STOREW`.
    pub rem_words_limbs: [T; RV32_REGISTER_NUM_LIMBS],

    /// Execution state (pc, timestamp) at the start of this row. On buffer rows after the
    /// first, trace generation only fills in the timestamp.
    pub from_state: ExecutionState<T>,
    /// Register holding the destination pointer (instruction operand `b`).
    pub mem_ptr_ptr: T,
    /// Little-endian limbs of the memory address written on this row.
    pub mem_ptr_limbs: [T; RV32_REGISTER_NUM_LIMBS],
    /// Auxiliary columns for the read of `mem_ptr_ptr`; only used on instruction-start rows.
    pub mem_ptr_aux_cols: MemoryReadAuxCols<T>,

    /// Auxiliary columns for the memory write of `data`.
    pub write_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
    /// The hint word written on this row; limbs are range-checked to `RV32_CELL_BITS` bits.
    pub data: [T; RV32_REGISTER_NUM_LIMBS],

    // only buffer
    /// 1 on the first row of a `HINT_BUFFER` instruction (implies `is_buffer == 1`).
    pub is_buffer_start: T,
    /// Register holding the word count (instruction operand `a`).
    pub num_words_ptr: T,
    /// Auxiliary columns for the read of `num_words_ptr`; only used when `is_buffer_start == 1`.
    pub num_words_aux_cols: MemoryReadAuxCols<T>,
}
72
/// AIR for the RV32 `HINT_STOREW` / `HINT_BUFFER` instructions, which copy words from the
/// hint stream into guest memory (`RV32_MEMORY_AS`).
#[derive(Copy, Clone, Debug)]
pub struct Rv32HintStoreAir {
    /// Bridge used to check the executed instruction and the pc/timestamp update.
    pub execution_bridge: ExecutionBridge,
    /// Bridge used for the register reads and the memory write of each row.
    pub memory_bridge: MemoryBridge,
    /// Bus used to range-check pointer/word-count limbs and hint bytes.
    pub bitwise_operation_lookup_bus: BitwiseOperationLookupBus,
    /// Offset added to local `Rv32HintStoreOpcode` values to form global opcodes.
    pub offset: usize,
    /// `mem_ptr` and `rem_words` are constrained to be strictly below `2^pointer_max_bits`.
    pointer_max_bits: usize,
}
81
82impl<F: Field> BaseAir<F> for Rv32HintStoreAir {
83    fn width(&self) -> usize {
84        Rv32HintStoreCols::<F>::width()
85    }
86}
87
88impl<F: Field> BaseAirWithPublicValues<F> for Rv32HintStoreAir {}
89impl<F: Field> PartitionedBaseAir<F> for Rv32HintStoreAir {}
90
impl<AB: InteractionBuilder> Air<AB> for Rv32HintStoreAir {
    /// Constrains one row of the hint-store trace and its transition to the next row.
    ///
    /// Every valid row writes one hint word to memory. A `HINT_STOREW` instruction is a
    /// single row with `is_single == 1`. A `HINT_BUFFER` instruction is a run of rows with
    /// `is_buffer == 1`, and only the first of them has `is_buffer_start == 1`. Rows of the
    /// same buffer instruction are linked by the transition constraints at the end:
    /// - `rem_words` drops by one;
    /// - `mem_ptr` advances by one word;
    /// - the timestamp advances by the per-row access count.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local_cols: &Rv32HintStoreCols<AB::Var> = (*local).borrow();
        let next = main.row_slice(1);
        let next_cols: &Rv32HintStoreCols<AB::Var> = (*next).borrow();

        let timestamp: AB::Var = local_cols.from_state.timestamp;
        // `timestamp_pp()` returns the timestamp for the next memory access on this row and
        // bumps the counter. After the three accesses below, `timestamp_delta == 3`.
        let mut timestamp_delta: usize = 0;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
        };

        // Flags are boolean, a buffer start is a buffer row, and a row is at most one of
        // single/buffer.
        builder.assert_bool(local_cols.is_single);
        builder.assert_bool(local_cols.is_buffer);
        builder.assert_bool(local_cols.is_buffer_start);
        builder
            .when(local_cols.is_buffer_start)
            .assert_one(local_cols.is_buffer);
        builder.assert_bool(local_cols.is_single + local_cols.is_buffer);

        let is_valid = local_cols.is_single + local_cols.is_buffer;
        let is_start = local_cols.is_single + local_cols.is_buffer_start;
        // `is_end` is false iff the next row is a buffer row that is not a buffer start,
        // i.e. iff the next row continues the current instruction.
        // This is boolean because is_buffer_start == 1 => is_buffer == 1.
        // Note: `is_end == 1` whenever the next row is invalid (e.g. all-zero padding).
        let is_end = not::<AB::Expr>(next_cols.is_buffer) + next_cols.is_buffer_start;

        // Compose the little-endian limbs into field elements (most significant limb first).
        let mut rem_words = AB::Expr::ZERO;
        let mut next_rem_words = AB::Expr::ZERO;
        let mut mem_ptr = AB::Expr::ZERO;
        let mut next_mem_ptr = AB::Expr::ZERO;
        for i in (0..RV32_REGISTER_NUM_LIMBS).rev() {
            rem_words = rem_words * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + local_cols.rem_words_limbs[i];
            next_rem_words = next_rem_words * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + next_cols.rem_words_limbs[i];
            mem_ptr = mem_ptr * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + local_cols.mem_ptr_limbs[i];
            next_mem_ptr = next_mem_ptr * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + next_cols.mem_ptr_limbs[i];
        }

        // Constrain that if local is invalid, then the next row is invalid as well
        // (all padding sits at the end of the trace).
        builder
            .when_transition()
            .when(not::<AB::Expr>(is_valid.clone()))
            .assert_zero(next_cols.is_single + next_cols.is_buffer);

        // A non-start buffer row must continue a buffer instruction. It may not follow a
        // single row, because a single row must be followed by a new instruction
        // (`is_end == 1`). It may not follow an invalid row (constraint above). It may not be
        // the first row of the trace: there, a buffer row must be a buffer start.
        builder
            .when(local_cols.is_single)
            .assert_one(is_end.clone());
        builder
            .when_first_row()
            .assert_one(not::<AB::Expr>(local_cols.is_buffer) + local_cols.is_buffer_start);

        // read mem_ptr (only on the first row of an instruction)
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.mem_ptr_ptr,
                ),
                local_cols.mem_ptr_limbs,
                timestamp_pp(),
                &local_cols.mem_ptr_aux_cols,
            )
            .eval(builder, is_start.clone());

        // read num_words (only on the first row of a HINT_BUFFER). The timestamp slot is
        // reserved on every row, so all rows use the same number of timestamps.
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.num_words_ptr,
                ),
                local_cols.rem_words_limbs,
                timestamp_pp(),
                &local_cols.num_words_aux_cols,
            )
            .eval(builder, local_cols.is_buffer_start);

        // write hint (on every valid row)
        self.memory_bridge
            .write(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_MEMORY_AS), mem_ptr.clone()),
                local_cols.data,
                timestamp_pp(),
                &local_cols.write_aux,
            )
            .eval(builder, is_valid.clone());

        // Exactly one of is_single / is_buffer is set on a valid row, so this selects the
        // global opcode of the instruction.
        let expected_opcode = (local_cols.is_single
            * AB::F::from_canonical_usize(HINT_STOREW as usize + self.offset))
            + (local_cols.is_buffer
                * AB::F::from_canonical_usize(HINT_BUFFER as usize + self.offset));

        // The whole instruction is checked once, on its start row. Its timestamp advances by
        // `timestamp_delta` per written word (`rem_words` is the total word count on the start row).
        self.execution_bridge
            .execute_and_increment_pc(
                expected_opcode,
                [
                    local_cols.is_buffer * (local_cols.num_words_ptr),
                    local_cols.mem_ptr_ptr.into(),
                    AB::Expr::ZERO,
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
                ],
                local_cols.from_state,
                rem_words.clone() * AB::F::from_canonical_usize(timestamp_delta),
            )
            .eval(builder, is_start.clone());

        // Preventing mem_ptr and rem_words overflow.
        // Range-checking
        //   mem_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1] * 2^(RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits)
        // to RV32_CELL_BITS bits constrains the most significant limb to be
        //   < 2^(pointer_max_bits - (RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS),
        // which implies mem_ptr < 2^pointer_max_bits.
        // Similarly for rem_words < 2^pointer_max_bits.
        self.bitwise_operation_lookup_bus
            .send_range(
                local_cols.mem_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1]
                    * AB::F::from_canonical_usize(
                        1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits),
                    ),
                local_cols.rem_words_limbs[RV32_REGISTER_NUM_LIMBS - 1]
                    * AB::F::from_canonical_usize(
                        1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits),
                    ),
            )
            .eval(builder, is_start.clone());

        // Checking that hint is bytes (two limbs per lookup)
        for i in 0..RV32_REGISTER_NUM_LIMBS / 2 {
            self.bitwise_operation_lookup_bus
                .send_range(local_cols.data[2 * i], local_cols.data[(2 * i) + 1])
                .eval(builder, is_valid.clone());
        }

        // buffer transition
        // `is_end` implies that the next row belongs to a new instruction,
        // which could be one of empty, hint_single, or hint_buffer.
        // Constrains that when the current row is not empty and `is_end == 1`, then `rem_words` is 1.
        builder
            .when(is_valid)
            .when(is_end.clone())
            .assert_one(rem_words.clone());

        // These constraints are not gated by `when_transition()`. On the last row, `next` is
        // presumably the wrapped-around first row, which `when_first_row` forces to be an
        // instruction start, so `is_end == 1` there. NOTE(review): confirm against the
        // backend's row-slice semantics.
        let mut when_buffer_transition = builder.when(not::<AB::Expr>(is_end.clone()));
        // Notes on `rem_words`: we constrain that `rem_words` doesn't overflow when we first
        // read it, and that it decreases by one on each row (below). We also constrain that
        // `rem_words` is 1 when the current instruction ends.
        // However, we don't constrain that the instruction must end when `rem_words` is 1.
        // Exploiting this would require adding some multiple of `p` extra illegal `buffer`
        // rows, where `p` is the modulus of `F`. Over `p` extra rows, `mem_ptr` would be
        // incremented to an illegal memory address at some point, which prevents this
        // exploit.
        when_buffer_transition.assert_one(rem_words.clone() - next_rem_words.clone());
        // Note: we only care about `next_mem_ptr = compose(next_mem_ptr_limbs)`, not the
        // individual limbs. The limbs do not need to be in range; they can be anything that
        // makes `next_mem_ptr` correct. This avoids a separate column for `mem_ptr`.
        // The constraint we care about is `next.mem_ptr == local.mem_ptr + 4`.
        // Since we increment by `4` each time, the memory bus rejects any out-of-bounds
        // access before the field can overflow.
        when_buffer_transition.assert_eq(
            next_mem_ptr.clone() - mem_ptr.clone(),
            AB::F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS),
        );
        // The next row of the same instruction starts right after this row's accesses.
        when_buffer_transition.assert_eq(
            timestamp + AB::F::from_canonical_usize(timestamp_delta),
            next_cols.from_state.timestamp,
        );
    }
}
262
/// Data captured while executing one hint-store instruction, consumed later by trace
/// generation.
#[derive(Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32HintStoreRecord<F: Field> {
    /// Execution state before the instruction ran.
    pub from_state: ExecutionState<u32>,
    /// The executed instruction: operand `a` is the word-count register and operand `b` is
    /// the destination-pointer register.
    pub instruction: Instruction<F>,
    /// Memory record of the destination-pointer register read.
    pub mem_ptr_read: RecordId,
    /// Address of the first word written.
    pub mem_ptr: u32,
    /// Number of words written (always 1 for `HINT_STOREW`).
    pub num_words: u32,

    /// Memory record of the word-count register read; `None` for `HINT_STOREW`.
    pub num_words_read: Option<RecordId>,
    /// Each written word with the memory record of its write, in increasing address order.
    pub hints: Vec<([F; RV32_REGISTER_NUM_LIMBS], RecordId)>,
}
275
/// Execution and trace-generation chip for [`Rv32HintStoreAir`].
pub struct Rv32HintStoreChip<F: Field> {
    air: Rv32HintStoreAir,
    /// One record per executed instruction, in execution order.
    pub records: Vec<Rv32HintStoreRecord<F>>,
    /// Trace rows used so far (total hint words across `records`), before padding.
    pub height: usize,
    /// Offline memory log used to build the memory auxiliary columns.
    offline_memory: Arc<Mutex<OfflineMemory<F>>>,
    /// Hint streams. Must be installed with `set_streams` before any instruction executes.
    pub streams: OnceLock<Arc<Mutex<Streams<F>>>>,
    /// Receives the range-check requests that match the AIR's bitwise-lookup sends.
    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
284
285impl<F: PrimeField32> Rv32HintStoreChip<F> {
286    pub fn new(
287        execution_bus: ExecutionBus,
288        program_bus: ProgramBus,
289        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
290        memory_bridge: MemoryBridge,
291        offline_memory: Arc<Mutex<OfflineMemory<F>>>,
292        pointer_max_bits: usize,
293        offset: usize,
294    ) -> Self {
295        let air = Rv32HintStoreAir {
296            execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
297            memory_bridge,
298            bitwise_operation_lookup_bus: bitwise_lookup_chip.bus(),
299            offset,
300            pointer_max_bits,
301        };
302        Self {
303            records: vec![],
304            air,
305            height: 0,
306            offline_memory,
307            streams: OnceLock::new(),
308            bitwise_lookup_chip,
309        }
310    }
311    pub fn set_streams(&mut self, streams: Arc<Mutex<Streams<F>>>) {
312        self.streams.set(streams).unwrap();
313    }
314}
315
316impl<F: PrimeField32> InstructionExecutor<F> for Rv32HintStoreChip<F> {
317    fn execute(
318        &mut self,
319        memory: &mut MemoryController<F>,
320        instruction: &Instruction<F>,
321        from_state: ExecutionState<u32>,
322    ) -> Result<ExecutionState<u32>, ExecutionError> {
323        let &Instruction {
324            opcode,
325            a: num_words_ptr,
326            b: mem_ptr_ptr,
327            d,
328            e,
329            ..
330        } = instruction;
331        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
332        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
333        let local_opcode =
334            Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
335
336        let (mem_ptr_read, mem_ptr_limbs) = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, mem_ptr_ptr);
337        let (num_words, num_words_read) = if local_opcode == HINT_STOREW {
338            memory.increment_timestamp();
339            (1, None)
340        } else {
341            let (num_words_read, num_words_limbs) =
342                memory.read::<RV32_REGISTER_NUM_LIMBS>(d, num_words_ptr);
343            (compose(num_words_limbs), Some(num_words_read))
344        };
345        debug_assert_ne!(num_words, 0);
346        debug_assert!(num_words <= (1 << self.air.pointer_max_bits));
347
348        let mem_ptr = compose(mem_ptr_limbs);
349
350        debug_assert!(mem_ptr <= (1 << self.air.pointer_max_bits));
351
352        let mut streams = self.streams.get().unwrap().lock().unwrap();
353        if streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize {
354            return Err(ExecutionError::HintOutOfBounds { pc: from_state.pc });
355        }
356
357        let mut record = Rv32HintStoreRecord {
358            from_state,
359            instruction: instruction.clone(),
360            mem_ptr_read,
361            mem_ptr,
362            num_words,
363            num_words_read,
364            hints: vec![],
365        };
366
367        for word_index in 0..num_words {
368            if word_index != 0 {
369                memory.increment_timestamp();
370                memory.increment_timestamp();
371            }
372
373            let data: [F; RV32_REGISTER_NUM_LIMBS] =
374                std::array::from_fn(|_| streams.hint_stream.pop_front().unwrap());
375            let (write, _) = memory.write(
376                e,
377                F::from_canonical_u32(mem_ptr + (RV32_REGISTER_NUM_LIMBS as u32 * word_index)),
378                data,
379            );
380            record.hints.push((data, write));
381        }
382
383        self.height += record.hints.len();
384        self.records.push(record);
385
386        let next_state = ExecutionState {
387            pc: from_state.pc + DEFAULT_PC_STEP,
388            timestamp: memory.timestamp(),
389        };
390        Ok(next_state)
391    }
392
393    fn get_opcode_name(&self, opcode: usize) -> String {
394        if opcode == HINT_STOREW.global_opcode().as_usize() {
395            String::from("HINT_STOREW")
396        } else if opcode == HINT_BUFFER.global_opcode().as_usize() {
397            String::from("HINT_BUFFER")
398        } else {
399            unreachable!("unsupported opcode: {}", opcode)
400        }
401    }
402}
403
404impl<F: Field> ChipUsageGetter for Rv32HintStoreChip<F> {
405    fn air_name(&self) -> String {
406        "Rv32HintStoreAir".to_string()
407    }
408
409    fn current_trace_height(&self) -> usize {
410        self.height
411    }
412
413    fn trace_width(&self) -> usize {
414        Rv32HintStoreCols::<F>::width()
415    }
416}
417
418impl<F: PrimeField32> Rv32HintStoreChip<F> {
419    // returns number of used u32s
420    fn record_to_rows(
421        record: Rv32HintStoreRecord<F>,
422        aux_cols_factory: &MemoryAuxColsFactory<F>,
423        slice: &mut [F],
424        memory: &OfflineMemory<F>,
425        bitwise_lookup_chip: &SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
426        pointer_max_bits: usize,
427    ) -> usize {
428        let width = Rv32HintStoreCols::<F>::width();
429        let cols: &mut Rv32HintStoreCols<F> = slice[..width].borrow_mut();
430
431        cols.is_single = F::from_bool(record.num_words_read.is_none());
432        cols.is_buffer = F::from_bool(record.num_words_read.is_some());
433        cols.is_buffer_start = cols.is_buffer;
434
435        cols.from_state = record.from_state.map(F::from_canonical_u32);
436        cols.mem_ptr_ptr = record.instruction.b;
437        aux_cols_factory.generate_read_aux(
438            memory.record_by_id(record.mem_ptr_read),
439            &mut cols.mem_ptr_aux_cols,
440        );
441
442        cols.num_words_ptr = record.instruction.a;
443        if let Some(num_words_read) = record.num_words_read {
444            aux_cols_factory.generate_read_aux(
445                memory.record_by_id(num_words_read),
446                &mut cols.num_words_aux_cols,
447            );
448        }
449
450        let mut mem_ptr = record.mem_ptr;
451        let mut rem_words = record.num_words;
452        let mut used_u32s = 0;
453
454        let mem_ptr_msl = mem_ptr >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS);
455        let rem_words_msl = rem_words >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS);
456        bitwise_lookup_chip.request_range(
457            mem_ptr_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits),
458            rem_words_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits),
459        );
460        for (i, &(data, write)) in record.hints.iter().enumerate() {
461            for half in 0..(RV32_REGISTER_NUM_LIMBS / 2) {
462                bitwise_lookup_chip.request_range(
463                    data[2 * half].as_canonical_u32(),
464                    data[2 * half + 1].as_canonical_u32(),
465                );
466            }
467
468            let cols: &mut Rv32HintStoreCols<F> = slice[used_u32s..used_u32s + width].borrow_mut();
469            cols.from_state.timestamp =
470                F::from_canonical_u32(record.from_state.timestamp + (3 * i as u32));
471            cols.data = data;
472            aux_cols_factory.generate_write_aux(memory.record_by_id(write), &mut cols.write_aux);
473            cols.rem_words_limbs = decompose(rem_words);
474            cols.mem_ptr_limbs = decompose(mem_ptr);
475            if i != 0 {
476                cols.is_buffer = F::ONE;
477            }
478            used_u32s += width;
479            mem_ptr += RV32_REGISTER_NUM_LIMBS as u32;
480            rem_words -= 1;
481        }
482
483        used_u32s
484    }
485
486    fn generate_trace(self) -> RowMajorMatrix<F> {
487        let width = self.trace_width();
488        let height = next_power_of_two_or_zero(self.height);
489        let mut flat_trace = F::zero_vec(width * height);
490
491        let memory = self.offline_memory.lock().unwrap();
492
493        let aux_cols_factory = memory.aux_cols_factory();
494
495        let mut used_u32s = 0;
496        for record in self.records {
497            used_u32s += Self::record_to_rows(
498                record,
499                &aux_cols_factory,
500                &mut flat_trace[used_u32s..],
501                &memory,
502                &self.bitwise_lookup_chip,
503                self.air.pointer_max_bits,
504            );
505        }
506        // padding rows can just be all zeros
507        RowMajorMatrix::new(flat_trace, width)
508    }
509}
510
511impl<SC: StarkGenericConfig> Chip<SC> for Rv32HintStoreChip<Val<SC>>
512where
513    Val<SC>: PrimeField32,
514{
515    fn air(&self) -> Arc<dyn AnyRap<SC>> {
516        Arc::new(self.air)
517    }
518    fn generate_air_proof_input(self) -> AirProofInput<SC> {
519        AirProofInput::simple_no_pis(self.generate_trace())
520    }
521}