// openvm_rv32im_circuit/loadstore/core.rs

use std::{
    array,
    borrow::{Borrow, BorrowMut},
    fmt::Debug,
};

use openvm_circuit::{
    arch::*,
    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::{AlignedBorrow, AlignedBytesBorrow};
use openvm_instructions::{
    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode,
};
use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};

use crate::adapters::{LoadStoreInstruction, Rv32LoadStoreAdapterFiller};
24
/// Internal tag for every supported `(opcode, shift)` combination; the digit appended to each
/// variant name is the shift amount.
///
/// A variant is used as `idx as usize` to index the `opcode_flags` table built in
/// `LoadStoreCoreAir::eval`, so the discriminants below must stay in sync with the order in
/// which that table is filled and with the flag assignment in
/// `LoadStoreFiller::fill_trace_row`:
/// - `0..=3`: `flags[k] == 2` for `k = 0..=3`
/// - `4..=7`: `flags[k] == 1` (and it is the only non-zero flag) for `k = 0..=3`
/// - `8..=13`: two flags equal to 1, for the index pairs
///   `(0,1) (0,2) (0,3) (1,2) (1,3) (2,3)` in that order
#[derive(Debug, Clone, Copy)]
enum InstructionOpcode {
    // Selected by a single flag equal to 2.
    LoadW0 = 0,
    LoadHu0 = 1,
    LoadHu2 = 2,
    LoadBu0 = 3,
    // Selected by a single flag equal to 1.
    LoadBu1 = 4,
    LoadBu2 = 5,
    LoadBu3 = 6,
    StoreW0 = 7,
    // Selected by a pair of flags both equal to 1.
    StoreH0 = 8,
    StoreH2 = 9,
    StoreB0 = 10,
    StoreB1 = 11,
    StoreB2 = 12,
    StoreB3 = 13,
}

use InstructionOpcode::*;
44
45/// LoadStore Core Chip handles byte/halfword into word conversions and unsigned extends
46/// This chip uses read_data and prev_data to constrain the write_data
47/// It also handles the shifting in case of not 4 byte aligned instructions
48/// This chips treats each (opcode, shift) pair as a separate instruction
49#[repr(C)]
50#[derive(Debug, Clone, AlignedBorrow)]
51pub struct LoadStoreCoreCols<T, const NUM_CELLS: usize> {
52    pub flags: [T; 4],
53    /// we need to keep the degree of is_valid and is_load to 1
54    pub is_valid: T,
55    pub is_load: T,
56
57    pub read_data: [T; NUM_CELLS],
58    pub prev_data: [T; NUM_CELLS],
59    /// write_data will be constrained against read_data and prev_data
60    /// depending on the opcode and the shift amount
61    pub write_data: [T; NUM_CELLS],
62}
63
64#[derive(Debug, Clone, derive_new::new)]
65pub struct LoadStoreCoreAir<const NUM_CELLS: usize> {
66    pub offset: usize,
67}
68
69impl<F: Field, const NUM_CELLS: usize> BaseAir<F> for LoadStoreCoreAir<NUM_CELLS> {
70    fn width(&self) -> usize {
71        LoadStoreCoreCols::<F, NUM_CELLS>::width()
72    }
73}
74
75impl<F: Field, const NUM_CELLS: usize> BaseAirWithPublicValues<F> for LoadStoreCoreAir<NUM_CELLS> {}
76
77impl<AB, I, const NUM_CELLS: usize> VmCoreAir<AB, I> for LoadStoreCoreAir<NUM_CELLS>
78where
79    AB: InteractionBuilder,
80    I: VmAdapterInterface<AB::Expr>,
81    I::Reads: From<([AB::Var; NUM_CELLS], [AB::Expr; NUM_CELLS])>,
82    I::Writes: From<[[AB::Expr; NUM_CELLS]; 1]>,
83    I::ProcessedInstruction: From<LoadStoreInstruction<AB::Expr>>,
84{
85    fn eval(
86        &self,
87        builder: &mut AB,
88        local_core: &[AB::Var],
89        _from_pc: AB::Var,
90    ) -> AdapterAirContext<AB::Expr, I> {
91        let cols: &LoadStoreCoreCols<AB::Var, NUM_CELLS> = (*local_core).borrow();
92        let LoadStoreCoreCols::<AB::Var, NUM_CELLS> {
93            read_data,
94            prev_data,
95            write_data,
96            flags,
97            is_valid,
98            is_load,
99        } = *cols;
100
101        let get_expr_12 = |x: &AB::Expr| (x.clone() - AB::Expr::ONE) * (x.clone() - AB::Expr::TWO);
102
103        builder.assert_bool(is_valid);
104        let sum = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
105            builder.assert_zero(flag * get_expr_12(&flag.into()));
106            acc + flag
107        });
108        builder.assert_zero(sum.clone() * get_expr_12(&sum));
109        // when sum is 0, is_valid must be 0
110        builder.when(get_expr_12(&sum)).assert_zero(is_valid);
111
112        // We will use the InstructionOpcode enum to encode the opcodes
113        // the appended digit to each opcode is the shift amount
114        let inv_2 = AB::F::from_canonical_u32(2).inverse();
115        let mut opcode_flags = vec![];
116        for flag in flags {
117            opcode_flags.push(flag * (flag - AB::F::ONE) * inv_2);
118        }
119        for flag in flags {
120            opcode_flags.push(flag * (sum.clone() - AB::F::TWO) * AB::F::NEG_ONE);
121        }
122        (0..4).for_each(|i| {
123            ((i + 1)..4).for_each(|j| opcode_flags.push(flags[i] * flags[j]));
124        });
125
126        let opcode_when = |idxs: &[InstructionOpcode]| -> AB::Expr {
127            idxs.iter().fold(AB::Expr::ZERO, |acc, &idx| {
128                acc + opcode_flags[idx as usize].clone()
129            })
130        };
131
132        // Constrain that is_load matches the opcode
133        builder.assert_eq(
134            is_load,
135            opcode_when(&[LoadW0, LoadHu0, LoadHu2, LoadBu0, LoadBu1, LoadBu2, LoadBu3]),
136        );
137        builder.when(is_load).assert_one(is_valid);
138
139        // There are three parts to write_data:
140        // - 1st limb is always read_data
141        // - 2nd to (NUM_CELLS/2)th limbs are:
142        //   - read_data if loadw/loadhu/storew/storeh
143        //   - prev_data if storeb
144        //   - zero if loadbu
145        // - (NUM_CELLS/2 + 1)th to last limbs are:
146        //   - read_data if loadw/storew
147        //   - prev_data if storeb/storeh
148        //   - zero if loadbu/loadhu
149        // Shifting needs to be carefully handled in case by case basis
150        // refer to [run_write_data] for the expected behavior in each case
151        for (i, cell) in write_data.iter().enumerate() {
152            // handling loads, expected_load_val = 0 if a store operation is happening
153            let expected_load_val = if i == 0 {
154                opcode_when(&[LoadW0, LoadHu0, LoadBu0]) * read_data[0]
155                    + opcode_when(&[LoadBu1]) * read_data[1]
156                    + opcode_when(&[LoadHu2, LoadBu2]) * read_data[2]
157                    + opcode_when(&[LoadBu3]) * read_data[3]
158            } else if i < NUM_CELLS / 2 {
159                opcode_when(&[LoadW0, LoadHu0]) * read_data[i]
160                    + opcode_when(&[LoadHu2]) * read_data[i + 2]
161            } else {
162                opcode_when(&[LoadW0]) * read_data[i]
163            };
164
165            // handling stores, expected_store_val = 0 if a load operation is happening
166            let expected_store_val = if i == 0 {
167                opcode_when(&[StoreW0, StoreH0, StoreB0]) * read_data[i]
168                    + opcode_when(&[StoreH2, StoreB1, StoreB2, StoreB3]) * prev_data[i]
169            } else if i == 1 {
170                opcode_when(&[StoreB1]) * read_data[i - 1]
171                    + opcode_when(&[StoreW0, StoreH0]) * read_data[i]
172                    + opcode_when(&[StoreH2, StoreB0, StoreB2, StoreB3]) * prev_data[i]
173            } else if i == 2 {
174                opcode_when(&[StoreH2, StoreB2]) * read_data[i - 2]
175                    + opcode_when(&[StoreW0]) * read_data[i]
176                    + opcode_when(&[StoreH0, StoreB0, StoreB1, StoreB3]) * prev_data[i]
177            } else if i == 3 {
178                opcode_when(&[StoreB3]) * read_data[i - 3]
179                    + opcode_when(&[StoreH2]) * read_data[i - 2]
180                    + opcode_when(&[StoreW0]) * read_data[i]
181                    + opcode_when(&[StoreH0, StoreB0, StoreB1, StoreB2]) * prev_data[i]
182            } else {
183                opcode_when(&[StoreW0]) * read_data[i]
184                    + opcode_when(&[StoreB0, StoreB1, StoreB2, StoreB3]) * prev_data[i]
185                    + opcode_when(&[StoreH0])
186                        * if i < NUM_CELLS / 2 {
187                            read_data[i]
188                        } else {
189                            prev_data[i]
190                        }
191                    + opcode_when(&[StoreH2])
192                        * if i - 2 < NUM_CELLS / 2 {
193                            read_data[i - 2]
194                        } else {
195                            prev_data[i]
196                        }
197            };
198            let expected_val = expected_load_val + expected_store_val;
199            builder.assert_eq(*cell, expected_val);
200        }
201
202        let expected_opcode = opcode_when(&[LoadW0]) * AB::Expr::from_canonical_u8(LOADW as u8)
203            + opcode_when(&[LoadHu0, LoadHu2]) * AB::Expr::from_canonical_u8(LOADHU as u8)
204            + opcode_when(&[LoadBu0, LoadBu1, LoadBu2, LoadBu3])
205                * AB::Expr::from_canonical_u8(LOADBU as u8)
206            + opcode_when(&[StoreW0]) * AB::Expr::from_canonical_u8(STOREW as u8)
207            + opcode_when(&[StoreH0, StoreH2]) * AB::Expr::from_canonical_u8(STOREH as u8)
208            + opcode_when(&[StoreB0, StoreB1, StoreB2, StoreB3])
209                * AB::Expr::from_canonical_u8(STOREB as u8);
210        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(self, expected_opcode);
211
212        let load_shift_amount = opcode_when(&[LoadBu1]) * AB::Expr::ONE
213            + opcode_when(&[LoadHu2, LoadBu2]) * AB::Expr::TWO
214            + opcode_when(&[LoadBu3]) * AB::Expr::from_canonical_u32(3);
215
216        let store_shift_amount = opcode_when(&[StoreB1]) * AB::Expr::ONE
217            + opcode_when(&[StoreH2, StoreB2]) * AB::Expr::TWO
218            + opcode_when(&[StoreB3]) * AB::Expr::from_canonical_u32(3);
219
220        AdapterAirContext {
221            to_pc: None,
222            reads: (prev_data, read_data.map(|x| x.into())).into(),
223            writes: [write_data.map(|x| x.into())].into(),
224            instruction: LoadStoreInstruction {
225                is_valid: is_valid.into(),
226                opcode: expected_opcode,
227                is_load: is_load.into(),
228                load_shift_amount,
229                store_shift_amount,
230            }
231            .into(),
232        }
233    }
234
235    fn start_offset(&self) -> usize {
236        self.offset
237    }
238}
239
240#[repr(C)]
241#[derive(AlignedBytesBorrow, Debug)]
242pub struct LoadStoreCoreRecord<const NUM_CELLS: usize> {
243    pub local_opcode: u8,
244    pub shift_amount: u8,
245    pub read_data: [u8; NUM_CELLS],
246    // Note: `prev_data` can be from native address space, so we need to use u32
247    pub prev_data: [u32; NUM_CELLS],
248}
249
250#[derive(Clone, Copy, derive_new::new)]
251pub struct LoadStoreExecutor<A, const NUM_CELLS: usize> {
252    adapter: A,
253    pub offset: usize,
254}
255
256#[derive(Clone, derive_new::new)]
257pub struct LoadStoreFiller<
258    A = Rv32LoadStoreAdapterFiller,
259    const NUM_CELLS: usize = RV32_REGISTER_NUM_LIMBS,
260> {
261    adapter: A,
262    pub offset: usize,
263}
264
265impl<F, A, RA, const NUM_CELLS: usize> PreflightExecutor<F, RA> for LoadStoreExecutor<A, NUM_CELLS>
266where
267    F: PrimeField32,
268    A: 'static
269        + AdapterTraceExecutor<
270            F,
271            ReadData = (([u32; NUM_CELLS], [u8; NUM_CELLS]), u8),
272            WriteData = [u32; NUM_CELLS],
273        >,
274    for<'buf> RA: RecordArena<
275        'buf,
276        EmptyAdapterCoreLayout<F, A>,
277        (A::RecordMut<'buf>, &'buf mut LoadStoreCoreRecord<NUM_CELLS>),
278    >,
279{
280    fn get_opcode_name(&self, opcode: usize) -> String {
281        format!(
282            "{:?}",
283            Rv32LoadStoreOpcode::from_usize(opcode - self.offset)
284        )
285    }
286
287    fn execute(
288        &self,
289        state: VmStateMut<F, TracingMemory, RA>,
290        instruction: &Instruction<F>,
291    ) -> Result<(), ExecutionError> {
292        let Instruction { opcode, .. } = instruction;
293
294        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
295
296        A::start(*state.pc, state.memory, &mut adapter_record);
297
298        (
299            (core_record.prev_data, core_record.read_data),
300            core_record.shift_amount,
301        ) = self
302            .adapter
303            .read(state.memory, instruction, &mut adapter_record);
304
305        let local_opcode = Rv32LoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset));
306        core_record.local_opcode = local_opcode as u8;
307
308        let write_data = run_write_data(
309            local_opcode,
310            core_record.read_data,
311            core_record.prev_data,
312            core_record.shift_amount as usize,
313        );
314        self.adapter
315            .write(state.memory, instruction, write_data, &mut adapter_record);
316
317        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
318
319        Ok(())
320    }
321}
322
323impl<F, A, const NUM_CELLS: usize> TraceFiller<F> for LoadStoreFiller<A, NUM_CELLS>
324where
325    F: PrimeField32,
326    A: 'static + AdapterTraceFiller<F>,
327{
328    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
329        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
330        // LoadStoreCoreCols::width() elements
331        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
332        self.adapter.fill_trace_row(mem_helper, adapter_row);
333        // SAFETY: core_row contains a valid LoadStoreCoreRecord written by the executor
334        // during trace generation
335        let record: &LoadStoreCoreRecord<NUM_CELLS> =
336            unsafe { get_record_from_slice(&mut core_row, ()) };
337        let core_row: &mut LoadStoreCoreCols<F, NUM_CELLS> = core_row.borrow_mut();
338
339        let opcode = Rv32LoadStoreOpcode::from_usize(record.local_opcode as usize);
340        let shift = record.shift_amount;
341
342        let write_data = run_write_data(opcode, record.read_data, record.prev_data, shift as usize);
343        // Writing in reverse order
344        core_row.write_data = write_data.map(F::from_canonical_u32);
345        core_row.prev_data = record.prev_data.map(F::from_canonical_u32);
346        core_row.read_data = record.read_data.map(F::from_canonical_u8);
347        core_row.is_load = F::from_bool([LOADW, LOADHU, LOADBU].contains(&opcode));
348        core_row.is_valid = F::ONE;
349        let flags = &mut core_row.flags;
350        *flags = [F::ZERO; 4];
351        match (opcode, shift) {
352            (LOADW, 0) => flags[0] = F::TWO,
353            (LOADHU, 0) => flags[1] = F::TWO,
354            (LOADHU, 2) => flags[2] = F::TWO,
355            (LOADBU, 0) => flags[3] = F::TWO,
356
357            (LOADBU, 1) => flags[0] = F::ONE,
358            (LOADBU, 2) => flags[1] = F::ONE,
359            (LOADBU, 3) => flags[2] = F::ONE,
360            (STOREW, 0) => flags[3] = F::ONE,
361
362            (STOREH, 0) => (flags[0], flags[1]) = (F::ONE, F::ONE),
363            (STOREH, 2) => (flags[0], flags[2]) = (F::ONE, F::ONE),
364            (STOREB, 0) => (flags[0], flags[3]) = (F::ONE, F::ONE),
365            (STOREB, 1) => (flags[1], flags[2]) = (F::ONE, F::ONE),
366            (STOREB, 2) => (flags[1], flags[3]) = (F::ONE, F::ONE),
367            (STOREB, 3) => (flags[2], flags[3]) = (F::ONE, F::ONE),
368            _ => unreachable!(),
369        };
370    }
371}
372
373// Returns the write data
374#[inline(always)]
375pub(super) fn run_write_data<const NUM_CELLS: usize>(
376    opcode: Rv32LoadStoreOpcode,
377    read_data: [u8; NUM_CELLS],
378    prev_data: [u32; NUM_CELLS],
379    shift: usize,
380) -> [u32; NUM_CELLS] {
381    match (opcode, shift) {
382        (LOADW, 0) => {
383            read_data.map(|x| x as u32)
384        },
385        (LOADBU, 0) | (LOADBU, 1) | (LOADBU, 2) | (LOADBU, 3) => {
386           let mut wrie_data = [0; NUM_CELLS];
387           wrie_data[0] = read_data[shift] as u32;
388           wrie_data
389        }
390        (LOADHU, 0) | (LOADHU, 2) => {
391            let mut write_data = [0; NUM_CELLS];
392            for (i, cell) in write_data.iter_mut().take(NUM_CELLS / 2).enumerate() {
393                *cell = read_data[i + shift] as u32;
394            }
395            write_data
396        }
397        (STOREW, 0) => {
398            read_data.map(|x| x as u32)
399        },
400        (STOREB, 0) | (STOREB, 1) | (STOREB, 2) | (STOREB, 3) => {
401            let mut write_data = prev_data;
402            write_data[shift] = read_data[0] as u32;
403            write_data
404        }
405        (STOREH, 0) | (STOREH, 2) => {
406            array::from_fn(|i| {
407                if i >= shift && i < (NUM_CELLS / 2 + shift){
408                    read_data[i - shift] as u32
409                } else {
410                    prev_data[i]
411                }
412            })
413        }
414        // Currently the adapter AIR requires `ptr_val` to be aligned to the data size in bytes.
415        // The circuit requires that `shift = ptr_val % 4` so that `ptr_val - shift` is a multiple of 4.
416        // This requirement is non-trivial to remove, because we use it to ensure that `ptr_val - shift + 4 <= 2^pointer_max_bits`.
417        _ => unreachable!(
418            "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}"
419        ),
420    }
421}