openvm_rv32im_circuit/adapters/loadstore.rs

use std::{
    array,
    borrow::{Borrow, BorrowMut},
    marker::PhantomData,
};

use openvm_circuit::{
    arch::{
        AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState,
        Result, VmAdapterAir, VmAdapterChip, VmAdapterInterface,
    },
    system::{
        memory::{
            offline_checker::{
                MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols,
            },
            MemoryAddress, MemoryController, OfflineMemory, RecordId,
        },
        program::ProgramBus,
    },
};
use openvm_circuit_primitives::{
    utils::{not, select},
    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
};
use serde::{Deserialize, Serialize};

use super::{compose, RV32_REGISTER_NUM_LIMBS};
use crate::adapters::RV32_CELL_BITS;

/// The LoadStore adapter handles all memory and register operations, so it must be aware
/// of the instruction type, specifically whether it is a load or a store.
/// It handles 4-byte aligned lw, sw instructions,
///            2-byte aligned lh, lhu, sh instructions, and
///            1-byte aligned lb, lbu, sb instructions.
/// This adapter always batch reads/writes 4 bytes, so when the intermediate pointer is not
/// 4-byte aligned it must shift the pointer down by the misalignment (the shift amount).
pub struct LoadStoreInstruction<T> {
    /// is_valid is constrained to be bool
    pub is_valid: T,
    /// Absolute opcode number
    pub opcode: T,
    /// is_load is constrained to be bool, and can only be 1 if is_valid is 1
    pub is_load: T,

    /// Keeping two separate shift amounts is needed to compute read_ptr/write_ptr with degree 2.
    /// load_shift_amount is the shift amount for loads and 0 for stores
    pub load_shift_amount: T,
    /// store_shift_amount is 0 for loads and the shift amount for stores
    pub store_shift_amount: T,
}

/// The LoadStoreAdapter separates the Runtime and Air AdapterInterfaces.
/// This is necessary because `prev_data` should be owned by the core chip and sent to the adapter,
/// and it must have an AB::Var type in the AIR so as to satisfy the memory_bridge interface.
/// This is achieved by having different types for reads and writes in the Air AdapterInterface.
/// This approach ensures that there are no modifications to the global interfaces.
///
/// Here the 2 reads represent read_data and prev_data.
/// The second element of the tuple in Reads is the shift amount that needs to be passed to the core chip.
/// Computing the intermediate pointer is completely internal to the adapter and shouldn't be a part of the AdapterInterface.
pub struct Rv32LoadStoreAdapterRuntimeInterface<T>(PhantomData<T>);
impl<T> VmAdapterInterface<T> for Rv32LoadStoreAdapterRuntimeInterface<T> {
    type Reads = ([[T; RV32_REGISTER_NUM_LIMBS]; 2], T);
    type Writes = [[T; RV32_REGISTER_NUM_LIMBS]; 1];
    type ProcessedInstruction = ();
}
pub struct Rv32LoadStoreAdapterAirInterface<AB: InteractionBuilder>(PhantomData<AB>);

/// Using AB::Var for prev_data and AB::Expr for read_data
impl<AB: InteractionBuilder> VmAdapterInterface<AB::Expr> for Rv32LoadStoreAdapterAirInterface<AB> {
    type Reads = (
        [AB::Var; RV32_REGISTER_NUM_LIMBS],
        [AB::Expr; RV32_REGISTER_NUM_LIMBS],
    );
    type Writes = [[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1];
    type ProcessedInstruction = LoadStoreInstruction<AB::Expr>;
}

/// This chip reads rs1 and computes an intermediate memory pointer address as rs1 + imm.
/// In case of Loads, it reads from the shifted intermediate pointer and writes to rd.
/// In case of Stores, it reads from rs2 and writes to the shifted intermediate pointer.
pub struct Rv32LoadStoreAdapterChip<F: Field> {
    pub air: Rv32LoadStoreAdapterAir,
    pub range_checker_chip: SharedVariableRangeCheckerChip,
    _marker: PhantomData<F>,
}

impl<F: PrimeField32> Rv32LoadStoreAdapterChip<F> {
    pub fn new(
        execution_bus: ExecutionBus,
        program_bus: ProgramBus,
        memory_bridge: MemoryBridge,
        pointer_max_bits: usize,
        range_checker_chip: SharedVariableRangeCheckerChip,
    ) -> Self {
        assert!(range_checker_chip.range_max_bits() >= 15);
        Self {
            air: Rv32LoadStoreAdapterAir {
                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
                memory_bridge,
                range_bus: range_checker_chip.bus(),
                pointer_max_bits,
            },
            range_checker_chip,
            _marker: PhantomData,
        }
    }
}

#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32LoadStoreReadRecord<F: Field> {
    pub rs1_record: RecordId,
    /// This will be a read from a register in case of Stores and a read from RISC-V memory in case of Loads.
    pub read: RecordId,
    pub rs1_ptr: F,
    pub imm: F,
    pub imm_sign: F,
    pub mem_as: F,
    pub mem_ptr_limbs: [u32; 2],
    pub shift_amount: u32,
}

#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32LoadStoreWriteRecord<F: Field> {
    /// This will be a write to a register in case of Loads and a write to RISC-V memory in case of Stores.
    /// For better struct packing, `RecordId(usize::MAX)` is used to indicate that there is no write.
    pub write_id: RecordId,
    pub from_state: ExecutionState<u32>,
    pub rd_rs2_ptr: F,
}

#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32LoadStoreAdapterCols<T> {
    pub from_state: ExecutionState<T>,
    pub rs1_ptr: T,
    pub rs1_data: [T; RV32_REGISTER_NUM_LIMBS],
    pub rs1_aux_cols: MemoryReadAuxCols<T>,

    /// This is rd (written to) for Loads and rs2 (read from) for Stores
    pub rd_rs2_ptr: T,
    pub read_data_aux: MemoryReadAuxCols<T>,
    pub imm: T,
    pub imm_sign: T,
    /// The limbs of the intermediate memory pointer, needed to verify the addition rs1 + imm
    pub mem_ptr_limbs: [T; 2],
    pub mem_as: T,
    /// prev_data will be provided by the core chip to make a complete MemoryWriteAuxCols
    pub write_base_aux: MemoryBaseAuxCols<T>,
    /// Only writes if `needs_write`.
    /// If the instruction is a Load:
    /// - Sets `needs_write` to 0 iff `rd == x0`
    ///
    /// Otherwise:
    /// - Sets `needs_write` to 1
    pub needs_write: T,
}

#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32LoadStoreAdapterAir {
    pub(super) memory_bridge: MemoryBridge,
    pub(super) execution_bridge: ExecutionBridge,
    pub range_bus: VariableRangeCheckerBus,
    pointer_max_bits: usize,
}

impl<F: Field> BaseAir<F> for Rv32LoadStoreAdapterAir {
    fn width(&self) -> usize {
        Rv32LoadStoreAdapterCols::<F>::width()
    }
}

impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32LoadStoreAdapterAir {
    type Interface = Rv32LoadStoreAdapterAirInterface<AB>;

    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();

        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
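        // `timestamp_pp` returns the timestamp for the current memory access and then advances
        // `timestamp_delta`, so each memory bridge interaction below uses the next timestamp.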
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
        };

        let is_load = ctx.instruction.is_load;
        let is_valid = ctx.instruction.is_valid;
        let load_shift_amount = ctx.instruction.load_shift_amount;
        let store_shift_amount = ctx.instruction.store_shift_amount;
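        // Per the `LoadStoreInstruction` docs, at most one of the two shift amounts is nonzero;
        // keeping them separate lets read_ptr/write_ptr each subtract only the relevant one
        // outside the `select`, keeping their degree at 2 (see the note before read_ptr below).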
        let shift_amount = load_shift_amount.clone() + store_shift_amount.clone();

        let write_count = local_cols.needs_write;

        // This constraint ensures that the memory write only occurs when `is_valid == 1`.
        builder.assert_bool(write_count);
        builder.when(write_count).assert_one(is_valid.clone());

        // Constrain that if `is_valid == 1` and `write_count == 0`, then `is_load == 1` and `rd_rs2_ptr == x0`
        builder
            .when(is_valid.clone() - write_count)
            .assert_one(is_load.clone());
        builder
            .when(is_valid.clone() - write_count)
            .assert_zero(local_cols.rd_rs2_ptr);

        // read rs1
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rs1_ptr,
                ),
                local_cols.rs1_data,
                timestamp_pp(),
                &local_cols.rs1_aux_cols,
            )
            .eval(builder, is_valid.clone());

        // constrain mem_ptr = rs1 + imm as a u32 addition with 2 limbs
        let limbs_01 = local_cols.rs1_data[0]
            + local_cols.rs1_data[1] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
        let limbs_23 = local_cols.rs1_data[2]
            + local_cols.rs1_data[3] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);

        let inv = AB::F::from_canonical_u32(1 << (RV32_CELL_BITS * 2)).inverse();
        let carry = (limbs_01 + local_cols.imm - local_cols.mem_ptr_limbs[0]) * inv;

        builder.when(is_valid.clone()).assert_bool(carry.clone());

        builder
            .when(is_valid.clone())
            .assert_bool(local_cols.imm_sign);
        let imm_extend_limb =
            local_cols.imm_sign * AB::F::from_canonical_u32((1 << (RV32_CELL_BITS * 2)) - 1);
        let carry = (limbs_23 + imm_extend_limb + carry - local_cols.mem_ptr_limbs[1]) * inv;
        builder.when(is_valid.clone()).assert_bool(carry.clone());
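        // Each carry above is computed in the field as (sum - result_limb) / 2^16; constraining it
        // to be boolean, together with the range checks on mem_ptr_limbs below, enforces that
        // mem_ptr_limbs is the 2x16-bit limb decomposition of the u32 sum rs1 + sign-extended imm.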

        // preventing mem_ptr overflow
        self.range_bus
            .range_check(
                // (limb[0] - shift_amount) / 4 < 2^14 => limb[0] - shift_amount < 2^16
                (local_cols.mem_ptr_limbs[0] - shift_amount)
                    * AB::F::from_canonical_u32(4).inverse(),
                RV32_CELL_BITS * 2 - 2,
            )
            .eval(builder, is_valid.clone());
        self.range_bus
            .range_check(
                local_cols.mem_ptr_limbs[1],
                self.pointer_max_bits - RV32_CELL_BITS * 2,
            )
            .eval(builder, is_valid.clone());
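        // Note: since the division by 4 is done in the field, the first range check above also
        // forces `mem_ptr_limbs[0] - shift_amount` to be an exact multiple of 4, so the 4-byte
        // batch access below happens at an aligned pointer.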

        let mem_ptr = local_cols.mem_ptr_limbs[0]
            + local_cols.mem_ptr_limbs[1] * AB::F::from_canonical_u32(1 << (RV32_CELL_BITS * 2));

        let is_store = is_valid.clone() - is_load.clone();
        // constrain mem_as to be in {0, 1, 2} if the instruction is a load,
        // and in {2, 3, 4} if the instruction is a store
        builder.assert_tern(local_cols.mem_as - is_store * AB::Expr::TWO);
        builder
            .when(not::<AB::Expr>(is_valid.clone()))
            .assert_zero(local_cols.mem_as);

        // read_as is [local_cols.mem_as] for loads and 1 for stores
        let read_as = select::<AB::Expr>(
            is_load.clone(),
            local_cols.mem_as,
            AB::F::from_canonical_u32(RV32_REGISTER_AS),
        );

        // read_ptr is mem_ptr for loads and rd_rs2_ptr for stores
        // Note: shift_amount is expected to have degree 2, thus we can't put it in the select clause
        //       since the resulting read_ptr/write_ptr's degree will be 3 which is too high.
        //       Instead, the solution without using additional columns is to get two different shift amounts from core chip
        let read_ptr = select::<AB::Expr>(is_load.clone(), mem_ptr.clone(), local_cols.rd_rs2_ptr)
            - load_shift_amount;

        self.memory_bridge
            .read(
                MemoryAddress::new(read_as, read_ptr),
                ctx.reads.1,
                timestamp_pp(),
                &local_cols.read_data_aux,
            )
            .eval(builder, is_valid.clone());

        let write_aux_cols = MemoryWriteAuxCols::from_base(local_cols.write_base_aux, ctx.reads.0);

        // write_as is 1 for loads and [local_cols.mem_as] for stores
        let write_as = select::<AB::Expr>(
            is_load.clone(),
            AB::F::from_canonical_u32(RV32_REGISTER_AS),
            local_cols.mem_as,
        );

        // write_ptr is rd_rs2_ptr for loads and mem_ptr for stores
        let write_ptr = select::<AB::Expr>(is_load.clone(), local_cols.rd_rs2_ptr, mem_ptr.clone())
            - store_shift_amount;

        self.memory_bridge
            .write(
                MemoryAddress::new(write_as, write_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &write_aux_cols,
            )
            .eval(builder, write_count);

        let to_pc = ctx
            .to_pc
            .unwrap_or(local_cols.from_state.pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP));
        self.execution_bridge
            .execute(
                ctx.instruction.opcode,
                [
                    local_cols.rd_rs2_ptr.into(),
                    local_cols.rs1_ptr.into(),
                    local_cols.imm.into(),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.mem_as.into(),
                    local_cols.needs_write.into(),
                    local_cols.imm_sign.into(),
                ],
                local_cols.from_state,
                ExecutionState {
                    pc: to_pc,
                    timestamp: timestamp + AB::F::from_canonical_usize(timestamp_delta),
                },
            )
            .eval(builder, is_valid);
    }

    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();
        local_cols.from_state.pc
    }
}

impl<F: PrimeField32> VmAdapterChip<F> for Rv32LoadStoreAdapterChip<F> {
    type ReadRecord = Rv32LoadStoreReadRecord<F>;
    type WriteRecord = Rv32LoadStoreWriteRecord<F>;
    type Air = Rv32LoadStoreAdapterAir;
    type Interface = Rv32LoadStoreAdapterRuntimeInterface<F>;

    #[allow(clippy::type_complexity)]
    fn preprocess(
        &mut self,
        memory: &mut MemoryController<F>,
        instruction: &Instruction<F>,
    ) -> Result<(
        <Self::Interface as VmAdapterInterface<F>>::Reads,
        Self::ReadRecord,
    )> {
        let Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            g,
            ..
        } = *instruction;
        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert!(e.as_canonical_u32() != RV32_IMM_AS);

        let local_opcode = Rv32LoadStoreOpcode::from_usize(
            opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
        );
        let rs1_record = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, b);

        let rs1_val = compose(rs1_record.1);
        let imm = c.as_canonical_u32();
        let imm_sign = g.as_canonical_u32();
        let imm_extended = imm + imm_sign * 0xffff0000;
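        // `imm` holds the low 16 bits of the immediate; when `imm_sign == 1`, the upper 16 bits
        // are filled with ones, sign-extending the immediate to a full 32-bit value.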

        let ptr_val = rs1_val.wrapping_add(imm_extended);
        let shift_amount = ptr_val % 4;
        assert!(
            ptr_val < (1 << self.air.pointer_max_bits),
            "ptr_val: {ptr_val} = rs1_val: {rs1_val} + imm_extended: {imm_extended} >= 2 ** {}",
            self.air.pointer_max_bits
        );

        let mem_ptr_limbs = array::from_fn(|i| ((ptr_val >> (i * (RV32_CELL_BITS * 2))) & 0xffff));
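        // For example: ptr_val = 0x12345 decomposes into mem_ptr_limbs = [0x2345, 0x0001]
        // with shift_amount = 1, so the aligned 4-byte batch access below starts at 0x12344.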

        let ptr_val = ptr_val - shift_amount;
        let read_record = match local_opcode {
            LOADW | LOADB | LOADH | LOADBU | LOADHU => {
                memory.read::<RV32_REGISTER_NUM_LIMBS>(e, F::from_canonical_u32(ptr_val))
            }
            STOREW | STOREH | STOREB => memory.read::<RV32_REGISTER_NUM_LIMBS>(d, a),
        };

        // We read the previous values of the target cells so that the bytes not touched by the
        // instruction stay unchanged when writing to those cells
        let prev_data = match local_opcode {
            STOREW | STOREH | STOREB => array::from_fn(|i| {
                memory.unsafe_read_cell(e, F::from_canonical_usize(ptr_val as usize + i))
            }),
            LOADW | LOADB | LOADH | LOADBU | LOADHU => {
                array::from_fn(|i| memory.unsafe_read_cell(d, a + F::from_canonical_usize(i)))
            }
        };

        Ok((
            (
                [prev_data, read_record.1],
                F::from_canonical_u32(shift_amount),
            ),
            Self::ReadRecord {
                rs1_record: rs1_record.0,
                rs1_ptr: b,
                read: read_record.0,
                imm: c,
                imm_sign: g,
                shift_amount,
                mem_ptr_limbs,
                mem_as: e,
            },
        ))
    }

    fn postprocess(
        &mut self,
        memory: &mut MemoryController<F>,
        instruction: &Instruction<F>,
        from_state: ExecutionState<u32>,
        output: AdapterRuntimeContext<F, Self::Interface>,
        read_record: &Self::ReadRecord,
    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
        let Instruction {
            opcode,
            a,
            d,
            e,
            f: enabled,
            ..
        } = *instruction;

        let local_opcode = Rv32LoadStoreOpcode::from_usize(
            opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
        );

        let write_id = if enabled != F::ZERO {
            let (record_id, _) = match local_opcode {
                STOREW | STOREH | STOREB => {
                    let ptr = read_record.mem_ptr_limbs[0]
                        + read_record.mem_ptr_limbs[1] * (1 << (RV32_CELL_BITS * 2));
                    memory.write(e, F::from_canonical_u32(ptr & 0xfffffffc), output.writes[0])
                }
                LOADW | LOADB | LOADH | LOADBU | LOADHU => memory.write(d, a, output.writes[0]),
            };
            record_id
        } else {
            memory.increment_timestamp();
            // RecordId will never get to usize::MAX, so it can be used as a flag for no write
            RecordId(usize::MAX)
        };

        Ok((
            ExecutionState {
                pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP),
                timestamp: memory.timestamp(),
            },
            Self::WriteRecord {
                from_state,
                write_id,
                rd_rs2_ptr: a,
            },
        ))
    }

    fn generate_trace_row(
        &self,
        row_slice: &mut [F],
        read_record: Self::ReadRecord,
        write_record: Self::WriteRecord,
        memory: &OfflineMemory<F>,
    ) {
        self.range_checker_chip.add_count(
            (read_record.mem_ptr_limbs[0] - read_record.shift_amount) / 4,
            RV32_CELL_BITS * 2 - 2,
        );
        self.range_checker_chip.add_count(
            read_record.mem_ptr_limbs[1],
            self.air.pointer_max_bits - RV32_CELL_BITS * 2,
        );
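        // These two counts mirror the range checks emitted by the AIR in `eval`, so the shared
        // range checker sees matching multiplicities for this row.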

        let aux_cols_factory = memory.aux_cols_factory();
        let adapter_cols: &mut Rv32LoadStoreAdapterCols<_> = row_slice.borrow_mut();
        adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32);
        let rs1 = memory.record_by_id(read_record.rs1_record);
        adapter_cols.rs1_data.copy_from_slice(rs1.data_slice());
        aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols);
        adapter_cols.rs1_ptr = read_record.rs1_ptr;
        adapter_cols.rd_rs2_ptr = write_record.rd_rs2_ptr;
        let read = memory.record_by_id(read_record.read);
        aux_cols_factory.generate_read_aux(read, &mut adapter_cols.read_data_aux);
        adapter_cols.imm = read_record.imm;
        adapter_cols.imm_sign = read_record.imm_sign;
        adapter_cols.mem_ptr_limbs = read_record.mem_ptr_limbs.map(F::from_canonical_u32);
        adapter_cols.mem_as = read_record.mem_as;
        if write_record.write_id.0 != usize::MAX {
            let write = memory.record_by_id(write_record.write_id);
            aux_cols_factory.generate_base_aux(write, &mut adapter_cols.write_base_aux);
            adapter_cols.needs_write = F::ONE;
        }
    }

    fn air(&self) -> &Self::Air {
        &self.air
    }
}