openvm_rv32im_circuit/adapters/
loadstore.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    marker::PhantomData,
4};
5
6use openvm_circuit::{
7    arch::{
8        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
9        ExecutionBridge, ExecutionState, VmAdapterAir, VmAdapterInterface,
10    },
11    system::{
12        memory::{
13            offline_checker::{
14                MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord,
15                MemoryWriteAuxCols,
16            },
17            online::TracingMemory,
18            MemoryAddress, MemoryAuxColsFactory,
19        },
20        native_adapter::util::{memory_read_native, timed_write_native},
21    },
22};
23use openvm_circuit_primitives::{
24    utils::{not, select},
25    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
26    AlignedBytesBorrow,
27};
28use openvm_circuit_primitives_derive::AlignedBorrow;
29use openvm_instructions::{
30    instruction::Instruction,
31    program::DEFAULT_PC_STEP,
32    riscv::{RV32_IMM_AS, RV32_MEMORY_AS, RV32_REGISTER_AS},
33    LocalOpcode, NATIVE_AS,
34};
35use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
36use openvm_stark_backend::{
37    interaction::InteractionBuilder,
38    p3_air::{AirBuilder, BaseAir},
39    p3_field::{Field, FieldAlgebra, PrimeField32},
40};
41
42use super::RV32_REGISTER_NUM_LIMBS;
43use crate::adapters::{memory_read, timed_write, tracing_read, RV32_CELL_BITS};
44
/// LoadStore Adapter handles all memory and register operations, so it must be aware
/// of the instruction type, specifically whether it is a load or store.
///
/// LoadStore Adapter handles 4 byte aligned lw, sw instructions,
///                           2 byte aligned lh, lhu, sh instructions and
///                           1 byte aligned lb, lbu, sb instructions.
///
/// This adapter always batch reads/writes 4 bytes, thus it needs to shift left the
/// memory pointer by some amount in case of not 4 byte aligned intermediate pointers.
pub struct LoadStoreInstruction<T> {
    /// is_valid is constrained to be bool
    pub is_valid: T,
    /// Absolute opcode number
    pub opcode: T,
    /// is_load is constrained to be bool, and can only be 1 if is_valid is 1
    pub is_load: T,

    /// Keeping two separate shift amounts is needed for getting the read_ptr/write_ptr
    /// with degree 2.
    /// load_shift_amount will be the shift amount if load and 0 if store
    pub load_shift_amount: T,
    /// store_shift_amount will be 0 if load and the shift amount if store
    pub store_shift_amount: T,
}
67
/// Zero-sized marker type carrying the `InteractionBuilder` so the adapter interface
/// impl below can name `AB::Var` / `AB::Expr` in its associated types.
pub struct Rv32LoadStoreAdapterAirInterface<AB: InteractionBuilder>(PhantomData<AB>);
69
/// Using AB::Var for prev_data and AB::Expr for read_data
impl<AB: InteractionBuilder> VmAdapterInterface<AB::Expr> for Rv32LoadStoreAdapterAirInterface<AB> {
    /// `(prev_data, read_data)`: prev_data comes straight from trace columns
    /// (`AB::Var`), while read_data is an expression supplied by the core chip.
    type Reads = (
        [AB::Var; RV32_REGISTER_NUM_LIMBS],
        [AB::Expr; RV32_REGISTER_NUM_LIMBS],
    );
    /// One batch write of `RV32_REGISTER_NUM_LIMBS` cells.
    type Writes = [[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1];
    type ProcessedInstruction = LoadStoreInstruction<AB::Expr>;
}
79
// NOTE: field order is the trace-column layout (repr(C) + AlignedBorrow); do not reorder.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32LoadStoreAdapterCols<T> {
    pub from_state: ExecutionState<T>,
    /// Pointer to the rs1 register
    pub rs1_ptr: T,
    /// rs1 value as little-endian limbs
    pub rs1_data: [T; RV32_REGISTER_NUM_LIMBS],
    pub rs1_aux_cols: MemoryReadAuxCols<T>,

    /// Will write to rd when Load and read from rs2 when Store
    pub rd_rs2_ptr: T,
    pub read_data_aux: MemoryReadAuxCols<T>,
    /// Lower 16 bits of the immediate
    pub imm: T,
    /// Boolean; when 1 the immediate is sign-extended with an all-ones upper half
    pub imm_sign: T,
    /// mem_ptr is the intermediate memory pointer limbs, needed to check the correct addition
    pub mem_ptr_limbs: [T; 2],
    /// Address space of the memory side of the access
    pub mem_as: T,
    /// prev_data will be provided by the core chip to make a complete MemoryWriteAuxCols
    pub write_base_aux: MemoryBaseAuxCols<T>,
    /// Only writes if `needs_write`.
    /// If the instruction is a Load:
    /// - Sets `needs_write` to 0 iff `rd == x0`
    ///
    /// Otherwise:
    /// - Sets `needs_write` to 1
    pub needs_write: T,
}
106
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32LoadStoreAdapterAir {
    pub(super) memory_bridge: MemoryBridge,
    pub(super) execution_bridge: ExecutionBridge,
    /// Bus used to range-check the memory-pointer limbs
    pub range_bus: VariableRangeCheckerBus,
    /// Maximum number of bits allowed in the final memory pointer
    pointer_max_bits: usize,
}
114
impl<F: Field> BaseAir<F> for Rv32LoadStoreAdapterAir {
    /// Trace width is exactly the flattened adapter column struct.
    fn width(&self) -> usize {
        Rv32LoadStoreAdapterCols::<F>::width()
    }
}
120
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32LoadStoreAdapterAir {
    type Interface = Rv32LoadStoreAdapterAirInterface<AB>;

    /// Constrains one load/store row: reads rs1, verifies the 2-limb pointer addition
    /// `mem_ptr = rs1 + imm`, range-checks the pointer, performs the data read and the
    /// conditional write, and sends the execution-bridge interaction.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();

        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Yields the current timestamp and advances by one tick per memory access.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
        };

        let is_load = ctx.instruction.is_load;
        let is_valid = ctx.instruction.is_valid;
        let load_shift_amount = ctx.instruction.load_shift_amount;
        let store_shift_amount = ctx.instruction.store_shift_amount;
        // Exactly one of the two shift amounts is nonzero, so their sum is "the" shift.
        let shift_amount = load_shift_amount.clone() + store_shift_amount.clone();

        let write_count = local_cols.needs_write;

        // This constraint ensures that the memory write only occurs when `is_valid == 1`.
        builder.assert_bool(write_count);
        builder.when(write_count).assert_one(is_valid.clone());

        // Constrain that if `is_valid == 1` and `write_count == 0`, then `is_load == 1` and
        // `rd_rs2_ptr == x0` (the only case where the write is skipped).
        builder
            .when(is_valid.clone() - write_count)
            .assert_one(is_load.clone());
        builder
            .when(is_valid.clone() - write_count)
            .assert_zero(local_cols.rd_rs2_ptr);

        // read rs1 (first memory access of the row)
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rs1_ptr,
                ),
                local_cols.rs1_data,
                timestamp_pp(),
                &local_cols.rs1_aux_cols,
            )
            .eval(builder, is_valid.clone());

        // constrain mem_ptr = rs1 + imm as a u32 addition with 2 limbs of 16 bits each
        let limbs_01 = local_cols.rs1_data[0]
            + local_cols.rs1_data[1] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
        let limbs_23 = local_cols.rs1_data[2]
            + local_cols.rs1_data[3] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);

        // Dividing by 2^16 recovers the carry; constraining it boolean proves the limb
        // decomposition of the sum is correct.
        let inv = AB::F::from_canonical_u32(1 << (RV32_CELL_BITS * 2)).inverse();
        let carry = (limbs_01 + local_cols.imm - local_cols.mem_ptr_limbs[0]) * inv;

        builder.when(is_valid.clone()).assert_bool(carry.clone());

        builder
            .when(is_valid.clone())
            .assert_bool(local_cols.imm_sign);
        // Upper 16 bits of the sign-extended immediate: all ones iff imm_sign == 1.
        let imm_extend_limb =
            local_cols.imm_sign * AB::F::from_canonical_u32((1 << (RV32_CELL_BITS * 2)) - 1);
        let carry = (limbs_23 + imm_extend_limb + carry - local_cols.mem_ptr_limbs[1]) * inv;
        builder.when(is_valid.clone()).assert_bool(carry.clone());

        // preventing mem_ptr overflow
        self.range_bus
            .range_check(
                // (limb[0] - shift_amount) / 4 < 2^14 => limb[0] - shift_amount < 2^16
                (local_cols.mem_ptr_limbs[0] - shift_amount)
                    * AB::F::from_canonical_u32(4).inverse(),
                RV32_CELL_BITS * 2 - 2,
            )
            .eval(builder, is_valid.clone());
        self.range_bus
            .range_check(
                local_cols.mem_ptr_limbs[1],
                self.pointer_max_bits - RV32_CELL_BITS * 2,
            )
            .eval(builder, is_valid.clone());

        let mem_ptr = local_cols.mem_ptr_limbs[0]
            + local_cols.mem_ptr_limbs[1] * AB::F::from_canonical_u32(1 << (RV32_CELL_BITS * 2));

        let is_store = is_valid.clone() - is_load.clone();
        // constrain mem_as to be in {0, 1, 2} if the instruction is a load,
        // and in {2, 3, 4} if the instruction is a store
        builder.assert_tern(local_cols.mem_as - is_store * AB::Expr::TWO);
        builder
            .when(not::<AB::Expr>(is_valid.clone()))
            .assert_zero(local_cols.mem_as);

        // read_as is [local_cols.mem_as] for loads and 1 for stores
        let read_as = select::<AB::Expr>(
            is_load.clone(),
            local_cols.mem_as,
            AB::F::from_canonical_u32(RV32_REGISTER_AS),
        );

        // read_ptr is mem_ptr for loads and rd_rs2_ptr for stores.
        // Note: shift_amount is expected to have degree 2, thus we can't put it in the
        // select clause since the resulting read_ptr/write_ptr's degree would be 3,
        // which is too high. Instead, the solution without using additional columns is
        // to get two different shift amounts from the core chip.
        let read_ptr = select::<AB::Expr>(is_load.clone(), mem_ptr.clone(), local_cols.rd_rs2_ptr)
            - load_shift_amount;

        // second memory access of the row
        self.memory_bridge
            .read(
                MemoryAddress::new(read_as, read_ptr),
                ctx.reads.1,
                timestamp_pp(),
                &local_cols.read_data_aux,
            )
            .eval(builder, is_valid.clone());

        // Combine the base aux columns with the core-chip-provided prev_data.
        let write_aux_cols = MemoryWriteAuxCols::from_base(local_cols.write_base_aux, ctx.reads.0);

        // write_as is 1 for loads and [local_cols.mem_as] for stores
        let write_as = select::<AB::Expr>(
            is_load.clone(),
            AB::F::from_canonical_u32(RV32_REGISTER_AS),
            local_cols.mem_as,
        );

        // write_ptr is rd_rs2_ptr for loads and mem_ptr for stores
        let write_ptr = select::<AB::Expr>(is_load.clone(), local_cols.rd_rs2_ptr, mem_ptr.clone())
            - store_shift_amount;

        // third memory access of the row; only counted when needs_write is set
        self.memory_bridge
            .write(
                MemoryAddress::new(write_as, write_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &write_aux_cols,
            )
            .eval(builder, write_count);

        let to_pc = ctx
            .to_pc
            .unwrap_or(local_cols.from_state.pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP));
        // Operand layout matches the executor's use of the Instruction fields:
        // [a = rd/rs2 ptr, b = rs1 ptr, c = imm, d = register AS, e = mem AS,
        //  f = needs_write, g = imm_sign].
        self.execution_bridge
            .execute(
                ctx.instruction.opcode,
                [
                    local_cols.rd_rs2_ptr.into(),
                    local_cols.rs1_ptr.into(),
                    local_cols.imm.into(),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.mem_as.into(),
                    local_cols.needs_write.into(),
                    local_cols.imm_sign.into(),
                ],
                local_cols.from_state,
                ExecutionState {
                    pc: to_pc,
                    timestamp: timestamp + AB::F::from_canonical_usize(timestamp_delta),
                },
            )
            .eval(builder, is_valid);
    }

    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();
        local_cols.from_state.pc
    }
}
294
// Packed execution record written into the trace buffer by the executor and later
// expanded into Rv32LoadStoreAdapterCols by the filler. repr(C): layout is read back
// via AlignedBytesBorrow, so field order matters.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32LoadStoreAdapterRecord {
    pub from_pc: u32,
    pub from_timestamp: u32,

    pub rs1_ptr: u32,
    /// rs1 register value, assembled from its little-endian bytes
    pub rs1_val: u32,
    pub rs1_aux_record: MemoryReadAuxRecord,

    /// Pointer of rd (loads) / rs2 (stores); set to `u32::MAX` as a sentinel when the
    /// write is disabled
    pub rd_rs2_ptr: u32,
    pub read_data_aux: MemoryReadAuxRecord,
    /// Lower 16 bits of the immediate
    pub imm: u16,
    /// Whether the immediate is sign-extended into the upper 16 bits
    pub imm_sign: bool,

    /// Address space of the memory side of the access
    pub mem_as: u8,

    /// Previous timestamp of the written cells (only meaningful when a write occurred)
    pub write_prev_timestamp: u32,
}
314
/// This chip reads rs1 and gets an intermediate memory pointer address with rs1 + imm.
/// In case of Loads, reads from the shifted intermediate pointer and writes to rd.
/// In case of Stores, reads from rs2 and writes to the shifted intermediate pointer.
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32LoadStoreAdapterExecutor {
    /// Maximum number of bits allowed in the computed memory pointer
    pointer_max_bits: usize,
}
322
/// Expands executor records into full trace rows and feeds the range checker.
#[derive(derive_new::new)]
pub struct Rv32LoadStoreAdapterFiller {
    /// Maximum number of bits allowed in the computed memory pointer
    pointer_max_bits: usize,
    /// Range checker receiving the pointer-limb counts emitted during trace filling
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}
328
329impl<F> AdapterTraceExecutor<F> for Rv32LoadStoreAdapterExecutor
330where
331    F: PrimeField32,
332{
333    const WIDTH: usize = size_of::<Rv32LoadStoreAdapterCols<u8>>();
334    type ReadData = (
335        (
336            [u32; RV32_REGISTER_NUM_LIMBS],
337            [u8; RV32_REGISTER_NUM_LIMBS],
338        ),
339        u8,
340    );
341    type WriteData = [u32; RV32_REGISTER_NUM_LIMBS];
342    type RecordMut<'a> = &'a mut Rv32LoadStoreAdapterRecord;
343
344    #[inline(always)]
345    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
346        record.from_pc = pc;
347        record.from_timestamp = memory.timestamp;
348    }
349
350    #[inline(always)]
351    fn read(
352        &self,
353        memory: &mut TracingMemory,
354        instruction: &Instruction<F>,
355        record: &mut Self::RecordMut<'_>,
356    ) -> Self::ReadData {
357        let &Instruction {
358            opcode,
359            a,
360            b,
361            c,
362            d,
363            e,
364            g,
365            ..
366        } = instruction;
367
368        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
369
370        let local_opcode = Rv32LoadStoreOpcode::from_usize(
371            opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
372        );
373
374        record.rs1_ptr = b.as_canonical_u32();
375        record.rs1_val = u32::from_le_bytes(tracing_read(
376            memory,
377            RV32_REGISTER_AS,
378            record.rs1_ptr,
379            &mut record.rs1_aux_record.prev_timestamp,
380        ));
381
382        record.imm = c.as_canonical_u32() as u16;
383        record.imm_sign = g.is_one();
384        let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000;
385
386        let ptr_val = record.rs1_val.wrapping_add(imm_extended);
387        let shift_amount = ptr_val & 3;
388        let ptr_val = ptr_val - shift_amount;
389
390        assert!(
391            ptr_val < (1 << self.pointer_max_bits),
392            "ptr_val: {ptr_val} = rs1_val: {} + imm_extended: {imm_extended} >= 2 ** {}",
393            record.rs1_val,
394            self.pointer_max_bits
395        );
396
397        // prev_data: We need to keep values of some cells to keep them unchanged when writing to
398        // those cells
399        let (read_data, prev_data) = match local_opcode {
400            LOADW | LOADB | LOADH | LOADBU | LOADHU => {
401                debug_assert_eq!(e, F::from_canonical_u32(RV32_MEMORY_AS));
402                record.mem_as = RV32_MEMORY_AS as u8;
403                let read_data = tracing_read(
404                    memory,
405                    RV32_MEMORY_AS,
406                    ptr_val,
407                    &mut record.read_data_aux.prev_timestamp,
408                );
409                let prev_data = memory_read(memory.data(), RV32_REGISTER_AS, a.as_canonical_u32())
410                    .map(u32::from);
411                (read_data, prev_data)
412            }
413            STOREW | STOREH | STOREB => {
414                let e = e.as_canonical_u32();
415                debug_assert_ne!(e, RV32_IMM_AS);
416                debug_assert_ne!(e, RV32_REGISTER_AS);
417                record.mem_as = e as u8;
418                let read_data = tracing_read(
419                    memory,
420                    RV32_REGISTER_AS,
421                    a.as_canonical_u32(),
422                    &mut record.read_data_aux.prev_timestamp,
423                );
424                let prev_data = if e == NATIVE_AS {
425                    memory_read_native(memory.data(), ptr_val).map(|x: F| x.as_canonical_u32())
426                } else {
427                    memory_read(memory.data(), e, ptr_val).map(u32::from)
428                };
429                (read_data, prev_data)
430            }
431        };
432
433        ((prev_data, read_data), shift_amount as u8)
434    }
435
436    #[inline(always)]
437    fn write(
438        &self,
439        memory: &mut TracingMemory,
440        instruction: &Instruction<F>,
441        data: Self::WriteData,
442        record: &mut Self::RecordMut<'_>,
443    ) {
444        let &Instruction {
445            opcode,
446            a,
447            d,
448            e,
449            f: enabled,
450            ..
451        } = instruction;
452
453        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
454        debug_assert_ne!(e.as_canonical_u32(), RV32_IMM_AS);
455        debug_assert_ne!(e.as_canonical_u32(), RV32_REGISTER_AS);
456
457        let local_opcode = Rv32LoadStoreOpcode::from_usize(
458            opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
459        );
460
461        if enabled != F::ZERO {
462            record.rd_rs2_ptr = a.as_canonical_u32();
463
464            record.write_prev_timestamp = match local_opcode {
465                STOREW | STOREH | STOREB => {
466                    let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000;
467                    let ptr = record.rs1_val.wrapping_add(imm_extended) & !3;
468
469                    if record.mem_as == 4 {
470                        timed_write_native(memory, ptr, data.map(F::from_canonical_u32)).0
471                    } else {
472                        timed_write(memory, record.mem_as as u32, ptr, data.map(|x| x as u8)).0
473                    }
474                }
475                LOADW | LOADB | LOADH | LOADBU | LOADHU => {
476                    timed_write(
477                        memory,
478                        RV32_REGISTER_AS,
479                        record.rd_rs2_ptr,
480                        data.map(|x| x as u8),
481                    )
482                    .0
483                }
484            };
485        } else {
486            record.rd_rs2_ptr = u32::MAX;
487            memory.increment_timestamp();
488        };
489    }
490}
491
impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32LoadStoreAdapterFiller {
    const WIDTH: usize = size_of::<Rv32LoadStoreAdapterCols<u8>>();

    /// Expands, in place, the packed `Rv32LoadStoreAdapterRecord` that the executor
    /// left at the start of `adapter_row` into the full `Rv32LoadStoreAdapterCols`
    /// layout, emitting the range-checker counts the AIR constraints expect.
    #[inline(always)]
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        debug_assert!(self.range_checker_chip.range_max_bits() >= 15);

        // SAFETY:
        // - caller ensures `adapter_row` contains a valid record representation that was previously
        //   written by the executor
        // - get_record_from_slice correctly interprets the bytes as Rv32LoadStoreAdapterRecord
        let record: &Rv32LoadStoreAdapterRecord =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let adapter_row: &mut Rv32LoadStoreAdapterCols<F> = adapter_row.borrow_mut();

        // `u32::MAX` is the executor's sentinel for "no write" (disabled instruction).
        let needs_write = record.rd_rs2_ptr != u32::MAX;
        // Writing in reverse order (last column first): `record` aliases the front of
        // the same row buffer, so filling from the highest-offset column backwards
        // avoids clobbering record fields before they are read.
        adapter_row.needs_write = F::from_bool(needs_write);

        if needs_write {
            // The write is the third access of the row: from_timestamp + 2.
            mem_helper.fill(
                record.write_prev_timestamp,
                record.from_timestamp + 2,
                &mut adapter_row.write_base_aux,
            );
        } else {
            mem_helper.fill_zero(&mut adapter_row.write_base_aux);
        }

        adapter_row.mem_as = F::from_canonical_u8(record.mem_as);
        // Recompute the (unaligned) intermediate pointer rs1 + sign-extended imm.
        let ptr = record
            .rs1_val
            .wrapping_add(record.imm as u32 + record.imm_sign as u32 * 0xffff0000);

        let ptr_limbs = [ptr & 0xffff, ptr >> 16];
        // Mirrors the AIR range checks: `>> 2` drops the two alignment bits, matching
        // `(mem_ptr_limbs[0] - shift_amount) / 4` on the constraint side.
        self.range_checker_chip
            .add_count(ptr_limbs[0] >> 2, RV32_CELL_BITS * 2 - 2);
        self.range_checker_chip
            .add_count(ptr_limbs[1], self.pointer_max_bits - 16);
        adapter_row.mem_ptr_limbs = ptr_limbs.map(F::from_canonical_u32);

        adapter_row.imm_sign = F::from_bool(record.imm_sign);
        adapter_row.imm = F::from_canonical_u16(record.imm);

        // The data read is the second access: from_timestamp + 1.
        mem_helper.fill(
            record.read_data_aux.prev_timestamp,
            record.from_timestamp + 1,
            adapter_row.read_data_aux.as_mut(),
        );
        adapter_row.rd_rs2_ptr = if record.rd_rs2_ptr != u32::MAX {
            F::from_canonical_u32(record.rd_rs2_ptr)
        } else {
            F::ZERO
        };

        // The rs1 read is the first access: at from_timestamp.
        mem_helper.fill(
            record.rs1_aux_record.prev_timestamp,
            record.from_timestamp,
            adapter_row.rs1_aux_cols.as_mut(),
        );

        adapter_row.rs1_data = record.rs1_val.to_le_bytes().map(F::from_canonical_u8);
        adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr);

        adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
        adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc);
    }
}
559}