openvm_rv32im_circuit/adapters/
loadstore.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    marker::PhantomData,
4};
5
6use openvm_circuit::{
7    arch::{
8        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
9        ExecutionBridge, ExecutionState, VmAdapterAir, VmAdapterInterface,
10    },
11    system::{
12        memory::{
13            offline_checker::{
14                MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord,
15                MemoryWriteAuxCols,
16            },
17            online::TracingMemory,
18            MemoryAddress, MemoryAuxColsFactory,
19        },
20        native_adapter::util::{memory_read_native, timed_write_native},
21    },
22};
23use openvm_circuit_primitives::{
24    utils::{not, select},
25    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
26    AlignedBytesBorrow,
27};
28use openvm_circuit_primitives_derive::AlignedBorrow;
29use openvm_instructions::{
30    instruction::Instruction,
31    program::DEFAULT_PC_STEP,
32    riscv::{RV32_IMM_AS, RV32_MEMORY_AS, RV32_REGISTER_AS},
33    LocalOpcode, NATIVE_AS,
34};
35use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
36use openvm_stark_backend::{
37    interaction::InteractionBuilder,
38    p3_air::{AirBuilder, BaseAir},
39    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
40};
41
42use super::RV32_REGISTER_NUM_LIMBS;
43use crate::adapters::{memory_read, timed_write, tracing_read, RV32_CELL_BITS};
44
/// LoadStore Adapter handles all memory and register operations, so it must be aware
/// of the instruction type, specifically whether it is a load or store.
/// LoadStore Adapter handles 4 byte aligned lw, sw instructions,
///                           2 byte aligned lh, lhu, sh instructions and
///                           1 byte aligned lb, lbu, sb instructions
/// This adapter always batch reads/writes 4 bytes,
/// thus it needs to shift left the memory pointer by some amount in case of not 4 byte aligned
/// intermediate pointers
pub struct LoadStoreInstruction<T> {
    /// is_valid is constrained to be bool
    pub is_valid: T,
    /// Absolute opcode number
    pub opcode: T,
    /// is_load is constrained to be bool, and can only be 1 if is_valid is 1
    pub is_load: T,

    /// Keeping two separate shift amounts is needed for getting the read_ptr/write_ptr with degree
    /// 2: load_shift_amount will be the shift amount if load and 0 if store
    pub load_shift_amount: T,
    /// store_shift_amount will be 0 if load and the shift amount if store.
    /// Exactly one of the two shift amounts can be nonzero, so their sum is the
    /// actual byte misalignment of the intermediate pointer.
    pub store_shift_amount: T,
}
67
68pub struct Rv32LoadStoreAdapterAirInterface<AB: InteractionBuilder>(PhantomData<AB>);
69
/// Using AB::Var for prev_data and AB::Expr for read_data
impl<AB: InteractionBuilder> VmAdapterInterface<AB::Expr> for Rv32LoadStoreAdapterAirInterface<AB> {
    /// `(prev_data, read_data)`: `prev_data` is supplied by the core chip and is
    /// used in `eval` to complete the write aux columns; `read_data` is the data
    /// returned from the adapter's data read.
    type Reads = (
        [AB::Var; RV32_REGISTER_NUM_LIMBS],
        [AB::Expr; RV32_REGISTER_NUM_LIMBS],
    );
    /// A single batch write of `RV32_REGISTER_NUM_LIMBS` cells.
    type Writes = [[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1];
    type ProcessedInstruction = LoadStoreInstruction<AB::Expr>;
}
79
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32LoadStoreAdapterCols<T> {
    pub from_state: ExecutionState<T>,
    /// Register pointer of rs1.
    pub rs1_ptr: T,
    /// The 4 little-endian limbs of the rs1 register value.
    pub rs1_data: [T; RV32_REGISTER_NUM_LIMBS],
    pub rs1_aux_cols: MemoryReadAuxCols<T>,

    /// Will write to rd when Load and read from rs2 when Store
    pub rd_rs2_ptr: T,
    pub read_data_aux: MemoryReadAuxCols<T>,
    /// The 16-bit immediate (low two limbs of the offset).
    pub imm: T,
    /// Boolean sign bit of the immediate; selects the 0xffff high-limb extension.
    pub imm_sign: T,
    /// mem_ptr is the intermediate memory pointer limbs, needed to check the correct addition.
    /// Two limbs of `RV32_CELL_BITS * 2` (= 16) bits each: [low, high].
    pub mem_ptr_limbs: [T; 2],
    /// Address space of the data access (the non-register side of the load/store).
    pub mem_as: T,
    /// prev_data will be provided by the core chip to make a complete MemoryWriteAuxCols
    pub write_base_aux: MemoryBaseAuxCols<T>,
    /// Only writes if `needs_write`.
    /// If the instruction is a Load:
    /// - Sets `needs_write` to 0 iff `rd == x0`
    ///
    /// Otherwise:
    /// - Sets `needs_write` to 1
    pub needs_write: T,
}
106
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32LoadStoreAdapterAir {
    pub(super) memory_bridge: MemoryBridge,
    pub(super) execution_bridge: ExecutionBridge,
    /// Bus used for the pointer range checks in `eval`.
    pub range_bus: VariableRangeCheckerBus,
    /// Maximum number of bits of a valid memory pointer; bounds `mem_ptr_limbs[1]`.
    pointer_max_bits: usize,
}
114
impl<F: Field> BaseAir<F> for Rv32LoadStoreAdapterAir {
    /// Trace width is the flattened column count of [`Rv32LoadStoreAdapterCols`].
    fn width(&self) -> usize {
        Rv32LoadStoreAdapterCols::<F>::width()
    }
}
120
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32LoadStoreAdapterAir {
    type Interface = Rv32LoadStoreAdapterAirInterface<AB>;

    /// Evaluates all adapter constraints: the rs1 register read, the two-limb
    /// `mem_ptr = rs1 + imm` addition with range checks, the data read and the
    /// conditional data write (whose address space and pointer depend on
    /// load vs store), and the execution-bridge interaction.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();

        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Yields `from_state.timestamp + k` for k = 0, 1, 2, ... on successive
        // calls, giving each memory interaction a distinct increasing timestamp.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_usize(timestamp_delta - 1)
        };

        let is_load = ctx.instruction.is_load;
        let is_valid = ctx.instruction.is_valid;
        let load_shift_amount = ctx.instruction.load_shift_amount;
        let store_shift_amount = ctx.instruction.store_shift_amount;
        // One of the two is zero, so the sum is the actual misalignment.
        let shift_amount = load_shift_amount.clone() + store_shift_amount.clone();

        let write_count = local_cols.needs_write;

        // This constraint ensures that the memory write only occurs when `is_valid == 1`.
        builder.assert_bool(write_count);
        builder.when(write_count).assert_one(is_valid.clone());

        // Constrain that if `is_valid == 1` and `write_count == 0`, then `is_load == 1` and
        // `rd_rs2_ptr == x0` (only a load into x0 may skip its write; stores always write).
        builder
            .when(is_valid.clone() - write_count)
            .assert_one(is_load.clone());
        builder
            .when(is_valid.clone() - write_count)
            .assert_zero(local_cols.rd_rs2_ptr);

        // read rs1 (timestamp offset 0)
        self.memory_bridge
            .read(
                MemoryAddress::new(AB::F::from_u32(RV32_REGISTER_AS), local_cols.rs1_ptr),
                local_cols.rs1_data,
                timestamp_pp(),
                &local_cols.rs1_aux_cols,
            )
            .eval(builder, is_valid.clone());

        // constrain mem_ptr = rs1 + imm as a u32 addition with 2 limbs.
        // First pack the four 8-bit register limbs into two 16-bit limbs.
        let limbs_01 =
            local_cols.rs1_data[0] + local_cols.rs1_data[1] * AB::F::from_u32(1 << RV32_CELL_BITS);
        let limbs_23 =
            local_cols.rs1_data[2] + local_cols.rs1_data[3] * AB::F::from_u32(1 << RV32_CELL_BITS);

        // `carry` recovers the carry-out of the low 16-bit limb addition by
        // scaling the overflow by 2^-16; it must be boolean for `mem_ptr_limbs`
        // to be a valid limb decomposition of the sum.
        let inv = AB::F::from_u32(1 << (RV32_CELL_BITS * 2)).inverse();
        let carry = (limbs_01 + local_cols.imm - local_cols.mem_ptr_limbs[0]) * inv;

        builder.when(is_valid.clone()).assert_bool(carry.clone());

        builder
            .when(is_valid.clone())
            .assert_bool(local_cols.imm_sign);
        // Sign-extension of the 16-bit immediate into the high limb:
        // 0xffff when `imm_sign` is set, 0 otherwise.
        let imm_extend_limb =
            local_cols.imm_sign * AB::F::from_u32((1 << (RV32_CELL_BITS * 2)) - 1);
        let carry = (limbs_23 + imm_extend_limb + carry - local_cols.mem_ptr_limbs[1]) * inv;
        builder.when(is_valid.clone()).assert_bool(carry.clone());

        // preventing mem_ptr overflow
        self.range_bus
            .range_check(
                // (limb[0] - shift_amount) / 4 < 2^14 => limb[0] - shift_amount < 2^16
                // (the aligned low limb is a multiple of 4, so division by 4 is exact)
                (local_cols.mem_ptr_limbs[0] - shift_amount) * AB::F::from_u32(4).inverse(),
                RV32_CELL_BITS * 2 - 2,
            )
            .eval(builder, is_valid.clone());
        self.range_bus
            .range_check(
                local_cols.mem_ptr_limbs[1],
                self.pointer_max_bits - RV32_CELL_BITS * 2,
            )
            .eval(builder, is_valid.clone());

        // Recombine the two 16-bit limbs into the full pointer expression.
        let mem_ptr = local_cols.mem_ptr_limbs[0]
            + local_cols.mem_ptr_limbs[1] * AB::F::from_u32(1 << (RV32_CELL_BITS * 2));

        let is_store = is_valid.clone() - is_load.clone();
        // constrain mem_as to be in {0, 1, 2} if the instruction is a load,
        // and in {2, 3, 4} if the instruction is a store
        builder.assert_tern(local_cols.mem_as - is_store * AB::Expr::TWO);
        // Invalid (padding) rows additionally pin `mem_as` to 0.
        builder
            .when(not::<AB::Expr>(is_valid.clone()))
            .assert_zero(local_cols.mem_as);

        // read_as is [local_cols.mem_as] for loads and 1 for stores
        let read_as = select::<AB::Expr>(
            is_load.clone(),
            local_cols.mem_as,
            AB::F::from_u32(RV32_REGISTER_AS),
        );

        // read_ptr is mem_ptr for loads and rd_rs2_ptr for stores
        // Note: shift_amount is expected to have degree 2, thus we can't put it in the select
        // clause       since the resulting read_ptr/write_ptr's degree will be 3 which is
        // too high.       Instead, the solution without using additional columns is to get
        // two different shift amounts from core chip
        let read_ptr = select::<AB::Expr>(is_load.clone(), mem_ptr.clone(), local_cols.rd_rs2_ptr)
            - load_shift_amount;

        // data read (timestamp offset 1)
        self.memory_bridge
            .read(
                MemoryAddress::new(read_as, read_ptr),
                ctx.reads.1,
                timestamp_pp(),
                &local_cols.read_data_aux,
            )
            .eval(builder, is_valid.clone());

        // Complete the write aux columns with the core-provided prev_data.
        let write_aux_cols = MemoryWriteAuxCols::from_base(local_cols.write_base_aux, ctx.reads.0);

        // write_as is 1 for loads and [local_cols.mem_as] for stores
        let write_as = select::<AB::Expr>(
            is_load.clone(),
            AB::F::from_u32(RV32_REGISTER_AS),
            local_cols.mem_as,
        );

        // write_ptr is rd_rs2_ptr for loads and mem_ptr for stores
        let write_ptr = select::<AB::Expr>(is_load.clone(), local_cols.rd_rs2_ptr, mem_ptr.clone())
            - store_shift_amount;

        // data write (timestamp offset 2), only counted when `needs_write`
        self.memory_bridge
            .write(
                MemoryAddress::new(write_as, write_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &write_aux_cols,
            )
            .eval(builder, write_count);

        // PC advances by DEFAULT_PC_STEP unless the core overrides `to_pc`.
        let to_pc = ctx
            .to_pc
            .unwrap_or(local_cols.from_state.pc + AB::F::from_u32(DEFAULT_PC_STEP));
        self.execution_bridge
            .execute(
                ctx.instruction.opcode,
                // Operand layout matches the executor's use of the instruction:
                // [a = rd/rs2, b = rs1, c = imm, d = register AS, e = mem AS,
                //  f = needs_write, g = imm_sign].
                [
                    local_cols.rd_rs2_ptr.into(),
                    local_cols.rs1_ptr.into(),
                    local_cols.imm.into(),
                    AB::Expr::from_u32(RV32_REGISTER_AS),
                    local_cols.mem_as.into(),
                    local_cols.needs_write.into(),
                    local_cols.imm_sign.into(),
                ],
                local_cols.from_state,
                ExecutionState {
                    pc: to_pc,
                    timestamp: timestamp + AB::F::from_usize(timestamp_delta),
                },
            )
            .eval(builder, is_valid);
    }

    /// Returns the `from_state.pc` column of the given row.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();
        local_cols.from_state.pc
    }
}
290
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32LoadStoreAdapterRecord {
    /// Program counter at the start of the instruction.
    pub from_pc: u32,
    /// Memory timestamp at the start of the instruction.
    pub from_timestamp: u32,

    /// Register pointer of rs1.
    pub rs1_ptr: u32,
    /// Value read from rs1 (assembled little-endian from its 4 limbs).
    pub rs1_val: u32,
    pub rs1_aux_record: MemoryReadAuxRecord,

    /// Register pointer of rd (loads) or rs2 (stores).
    /// Set to `u32::MAX` as a sentinel when the write is disabled.
    pub rd_rs2_ptr: u32,
    pub read_data_aux: MemoryReadAuxRecord,
    /// The 16-bit immediate.
    pub imm: u16,
    /// Sign bit of the immediate; selects the 0xffff0000 extension.
    pub imm_sign: bool,

    /// Address space of the data access.
    pub mem_as: u8,

    /// Previous timestamp of the written cells (only meaningful when a write occurred).
    pub write_prev_timestamp: u32,
}
310
/// This chip reads rs1 and gets a intermediate memory pointer address with rs1 + imm.
/// In case of Loads, reads from the shifted intermediate pointer and writes to rd.
/// In case of Stores, reads from rs2 and writes to the shifted intermediate pointer.
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32LoadStoreAdapterExecutor {
    /// Maximum number of bits of a valid memory pointer; enforced at runtime in `read`.
    pointer_max_bits: usize,
}
318
/// Trace filler that expands [`Rv32LoadStoreAdapterRecord`]s into
/// [`Rv32LoadStoreAdapterCols`] rows and emits the matching range checks.
#[derive(derive_new::new)]
pub struct Rv32LoadStoreAdapterFiller {
    /// Must match the AIR's `pointer_max_bits`.
    pointer_max_bits: usize,
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}
324
impl<F> AdapterTraceExecutor<F> for Rv32LoadStoreAdapterExecutor
where
    F: PrimeField32,
{
    const WIDTH: usize = size_of::<Rv32LoadStoreAdapterCols<u8>>();
    /// `((prev_data, read_data), shift_amount)`:
    /// - `prev_data`: current contents of the cells the upcoming write will
    ///   overwrite (rd for loads, memory for stores), read without advancing the
    ///   timestamp;
    /// - `read_data`: the 4 bytes actually read (memory for loads, rs2 for stores);
    /// - `shift_amount`: low 2 bits of the unaligned intermediate pointer.
    type ReadData = (
        (
            [u32; RV32_REGISTER_NUM_LIMBS],
            [u8; RV32_REGISTER_NUM_LIMBS],
        ),
        u8,
    );
    type WriteData = [u32; RV32_REGISTER_NUM_LIMBS];
    type RecordMut<'a> = &'a mut Rv32LoadStoreAdapterRecord;

    /// Captures the starting pc and memory timestamp into the record.
    #[inline(always)]
    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
        record.from_pc = pc;
        record.from_timestamp = memory.timestamp;
    }

    /// Performs the two timestamped reads of this instruction (rs1, then the
    /// data read) and captures everything the filler needs into `record`.
    ///
    /// Operand meanings: a = rd/rs2 pointer, b = rs1 pointer, c = immediate,
    /// d = register address space (asserted), e = data address space, g = imm sign.
    #[inline(always)]
    fn read(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        record: &mut Self::RecordMut<'_>,
    ) -> Self::ReadData {
        let &Instruction {
            opcode,
            a,
            b,
            c,
            d,
            e,
            g,
            ..
        } = instruction;

        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);

        let local_opcode = Rv32LoadStoreOpcode::from_usize(
            opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
        );

        // Timestamped read of rs1 (timestamp offset 0).
        record.rs1_ptr = b.as_canonical_u32();
        record.rs1_val = u32::from_le_bytes(tracing_read(
            memory,
            RV32_REGISTER_AS,
            record.rs1_ptr,
            &mut record.rs1_aux_record.prev_timestamp,
        ));

        // Sign-extend the 16-bit immediate to 32 bits.
        record.imm = c.as_canonical_u32() as u16;
        record.imm_sign = g.is_one();
        let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000;

        // Align the intermediate pointer down to a 4-byte boundary; the low
        // 2 bits become the shift amount returned to the core chip.
        let ptr_val = record.rs1_val.wrapping_add(imm_extended);
        let shift_amount = ptr_val & 3;
        let ptr_val = ptr_val - shift_amount;

        assert!(
            ptr_val < (1 << self.pointer_max_bits),
            "ptr_val: {ptr_val} = rs1_val: {} + imm_extended: {imm_extended} >= 2 ** {}",
            record.rs1_val,
            self.pointer_max_bits
        );

        // prev_data: We need to keep values of some cells to keep them unchanged when writing to
        // those cells
        let (read_data, prev_data) = match local_opcode {
            LOADW | LOADB | LOADH | LOADBU | LOADHU => {
                // Loads: timestamped batch read from memory at the aligned pointer
                // (timestamp offset 1); prev_data is rd's current contents,
                // fetched without advancing the timestamp.
                debug_assert_eq!(e, F::from_u32(RV32_MEMORY_AS));
                record.mem_as = RV32_MEMORY_AS as u8;
                let read_data = tracing_read(
                    memory,
                    RV32_MEMORY_AS,
                    ptr_val,
                    &mut record.read_data_aux.prev_timestamp,
                );
                let prev_data = memory_read(memory.data(), RV32_REGISTER_AS, a.as_canonical_u32())
                    .map(u32::from);
                (read_data, prev_data)
            }
            STOREW | STOREH | STOREB => {
                // Stores: timestamped read of rs2 (timestamp offset 1); prev_data
                // is the current memory contents at the aligned pointer, with
                // the native address space read through its own helper.
                let e = e.as_canonical_u32();
                debug_assert_ne!(e, RV32_IMM_AS);
                debug_assert_ne!(e, RV32_REGISTER_AS);
                record.mem_as = e as u8;
                let read_data = tracing_read(
                    memory,
                    RV32_REGISTER_AS,
                    a.as_canonical_u32(),
                    &mut record.read_data_aux.prev_timestamp,
                );
                let prev_data = if e == NATIVE_AS {
                    memory_read_native(memory.data(), ptr_val).map(|x: F| x.as_canonical_u32())
                } else {
                    memory_read(memory.data(), e, ptr_val).map(u32::from)
                };
                (read_data, prev_data)
            }
        };

        ((prev_data, read_data), shift_amount as u8)
    }

    /// Performs the conditional write (timestamp offset 2). When operand f
    /// (`enabled`) is zero — per the column docs, a load targeting x0 — no write
    /// happens: the timestamp is still advanced and `rd_rs2_ptr` is set to the
    /// `u32::MAX` sentinel the filler uses to detect "no write".
    #[inline(always)]
    fn write(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        data: Self::WriteData,
        record: &mut Self::RecordMut<'_>,
    ) {
        let &Instruction {
            opcode,
            a,
            d,
            e,
            f: enabled,
            ..
        } = instruction;

        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_ne!(e.as_canonical_u32(), RV32_IMM_AS);
        debug_assert_ne!(e.as_canonical_u32(), RV32_REGISTER_AS);

        let local_opcode = Rv32LoadStoreOpcode::from_usize(
            opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
        );

        if enabled != F::ZERO {
            record.rd_rs2_ptr = a.as_canonical_u32();

            record.write_prev_timestamp = match local_opcode {
                STOREW | STOREH | STOREB => {
                    // Recompute the 4-byte-aligned destination pointer
                    // (`& !3` clears the shift bits, matching `read`).
                    let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000;
                    let ptr = record.rs1_val.wrapping_add(imm_extended) & !3;

                    if record.mem_as == 4 {
                        // Native address space stores field elements, not bytes.
                        timed_write_native(memory, ptr, data.map(F::from_u32)).0
                    } else {
                        timed_write(memory, record.mem_as as u32, ptr, data.map(|x| x as u8)).0
                    }
                }
                LOADW | LOADB | LOADH | LOADBU | LOADHU => {
                    // Loads write the (shifted/extended) result into rd.
                    timed_write(
                        memory,
                        RV32_REGISTER_AS,
                        record.rd_rs2_ptr,
                        data.map(|x| x as u8),
                    )
                    .0
                }
            };
        } else {
            // Disabled write: keep the timestamp schedule consistent with the
            // AIR (which still allocates a slot for the write) and record the
            // sentinel for the filler.
            record.rd_rs2_ptr = u32::MAX;
            memory.increment_timestamp();
        };
    }
}
487
impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32LoadStoreAdapterFiller {
    const WIDTH: usize = size_of::<Rv32LoadStoreAdapterCols<u8>>();

    /// Expands the packed [`Rv32LoadStoreAdapterRecord`] stored in `adapter_row`
    /// into the full [`Rv32LoadStoreAdapterCols`] layout, in place, and emits the
    /// range-checker counts matching the AIR's two range checks.
    #[inline(always)]
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        debug_assert!(self.range_checker_chip.range_max_bits() >= 15);

        // SAFETY:
        // - caller ensures `adapter_row` contains a valid record representation that was previously
        //   written by the executor
        // - get_record_from_slice correctly interprets the bytes as Rv32LoadStoreAdapterRecord
        // NOTE: the record aliases the front of `adapter_row`, so the columns
        // below are filled back-to-front to avoid overwriting record bytes
        // before they are read.
        let record: &Rv32LoadStoreAdapterRecord =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let adapter_row: &mut Rv32LoadStoreAdapterCols<F> = adapter_row.borrow_mut();

        // `u32::MAX` is the executor's sentinel for "write disabled".
        let needs_write = record.rd_rs2_ptr != u32::MAX;
        // Writing in reverse order
        adapter_row.needs_write = F::from_bool(needs_write);

        if needs_write {
            // The write is the third access: timestamp offset 2 (after rs1 at +0
            // and the data read at +1), matching the AIR's `timestamp_pp` order.
            mem_helper.fill(
                record.write_prev_timestamp,
                record.from_timestamp + 2,
                &mut adapter_row.write_base_aux,
            );
        } else {
            mem_helper.fill_zero(&mut adapter_row.write_base_aux);
        }

        adapter_row.mem_as = F::from_u8(record.mem_as);
        // Recompute the (unaligned) intermediate pointer rs1 + sign-extended imm.
        let ptr = record
            .rs1_val
            .wrapping_add(record.imm as u32 + record.imm_sign as u32 * 0xffff0000);

        // Decompose into the two 16-bit limbs the AIR constrains, and emit the
        // matching range checks. `ptr_limbs[0] >> 2` drops the 2-bit shift,
        // equal to the AIR's `(mem_ptr_limbs[0] - shift_amount) / 4`.
        let ptr_limbs = [ptr & 0xffff, ptr >> 16];
        self.range_checker_chip
            .add_count(ptr_limbs[0] >> 2, RV32_CELL_BITS * 2 - 2);
        self.range_checker_chip
            .add_count(ptr_limbs[1], self.pointer_max_bits - 16);
        adapter_row.mem_ptr_limbs = ptr_limbs.map(F::from_u32);

        adapter_row.imm_sign = F::from_bool(record.imm_sign);
        adapter_row.imm = F::from_u16(record.imm);

        // Data read: timestamp offset 1.
        mem_helper.fill(
            record.read_data_aux.prev_timestamp,
            record.from_timestamp + 1,
            adapter_row.read_data_aux.as_mut(),
        );
        // Replace the sentinel with 0 so the AIR's `rd_rs2_ptr == x0` constraint
        // holds on disabled-write rows.
        adapter_row.rd_rs2_ptr = if record.rd_rs2_ptr != u32::MAX {
            F::from_u32(record.rd_rs2_ptr)
        } else {
            F::ZERO
        };

        // rs1 read: timestamp offset 0.
        mem_helper.fill(
            record.rs1_aux_record.prev_timestamp,
            record.from_timestamp,
            adapter_row.rs1_aux_cols.as_mut(),
        );

        adapter_row.rs1_data = record.rs1_val.to_le_bytes().map(F::from_u8);
        adapter_row.rs1_ptr = F::from_u32(record.rs1_ptr);

        adapter_row.from_state.timestamp = F::from_u32(record.from_timestamp);
        adapter_row.from_state.pc = F::from_u32(record.from_pc);
    }
}
555}