openvm_rv32im_circuit/loadstore/
execution.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    fmt::Debug,
4    mem::size_of,
5};
6
7use openvm_circuit::{
8    arch::*,
9    system::memory::{online::GuestMemory, POINTER_MAX_BITS},
10};
11use openvm_circuit_primitives::AlignedBytesBorrow;
12use openvm_instructions::{
13    instruction::Instruction,
14    program::DEFAULT_PC_STEP,
15    riscv::{RV32_IMM_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
16    LocalOpcode, NATIVE_AS,
17};
18use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
19use openvm_stark_backend::p3_field::PrimeField32;
20
21use super::core::LoadStoreExecutor;
22
/// Operands of a single load/store instruction, filled in once by `pre_compute_impl` and
/// borrowed back from raw bytes by the execute handlers.
#[derive(AlignedBytesBorrow, Clone)]
#[repr(C)]
struct LoadStorePreCompute {
    /// Offset added (wrapping) to `rs1`; computed as `c + g * 0xffff0000`.
    imm_extended: u32,
    /// Operand `a` truncated to `u8`: the register written by loads / read by stores.
    a: u8,
    /// Operand `b` truncated to `u8`: the register holding the base address `rs1`.
    b: u8,
    /// Operand `e` truncated to `u8`: address space of the memory being loaded or stored.
    e: u8,
}
31
32impl<A, const NUM_CELLS: usize> LoadStoreExecutor<A, NUM_CELLS> {
33    /// Return (local_opcode, enabled, is_native_store)
34    fn pre_compute_impl<F: PrimeField32>(
35        &self,
36        pc: u32,
37        inst: &Instruction<F>,
38        data: &mut LoadStorePreCompute,
39    ) -> Result<(Rv32LoadStoreOpcode, bool, bool), StaticProgramError> {
40        let Instruction {
41            opcode,
42            a,
43            b,
44            c,
45            d,
46            e,
47            f,
48            g,
49            ..
50        } = inst;
51        let enabled = !f.is_zero();
52
53        let e_u32 = e.as_canonical_u32();
54        if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 == RV32_IMM_AS {
55            return Err(StaticProgramError::InvalidInstruction(pc));
56        }
57
58        let local_opcode = Rv32LoadStoreOpcode::from_usize(
59            opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
60        );
61        match local_opcode {
62            LOADW | LOADBU | LOADHU => {}
63            STOREW | STOREH | STOREB => {
64                if !enabled {
65                    return Err(StaticProgramError::InvalidInstruction(pc));
66                }
67            }
68            _ => unreachable!("LoadStoreExecutor should not handle LOADB/LOADH opcodes"),
69        }
70
71        let imm = c.as_canonical_u32();
72        let imm_sign = g.as_canonical_u32();
73        let imm_extended = imm + imm_sign * 0xffff0000;
74        let is_native_store = e_u32 == NATIVE_AS;
75
76        *data = LoadStorePreCompute {
77            imm_extended,
78            a: a.as_canonical_u32() as u8,
79            b: b.as_canonical_u32() as u8,
80            e: e_u32 as u8,
81        };
82        Ok((local_opcode, enabled, is_native_store))
83    }
84}
85
// Picks the instantiation of `$execute_impl` matching the decoded opcode, the `enabled`
// flag and whether the target address space is the native one. Loads always use `U8`
// cells; stores use `F` cells for the native address space and `U8` cells otherwise.
// The scrutinee is ordered (opcode, is_native_store, enabled).
macro_rules! dispatch {
    ($execute_impl:ident, $local_opcode:ident, $enabled:ident, $is_native_store:ident) => {
        match ($local_opcode, $is_native_store, $enabled) {
            // Loads: the native-store flag is irrelevant.
            (LOADW, _, true) => Ok($execute_impl::<_, _, U8, LoadWOp, true>),
            (LOADW, _, false) => Ok($execute_impl::<_, _, U8, LoadWOp, false>),
            (LOADHU, _, true) => Ok($execute_impl::<_, _, U8, LoadHUOp, true>),
            (LOADHU, _, false) => Ok($execute_impl::<_, _, U8, LoadHUOp, false>),
            (LOADBU, _, true) => Ok($execute_impl::<_, _, U8, LoadBUOp, true>),
            (LOADBU, _, false) => Ok($execute_impl::<_, _, U8, LoadBUOp, false>),

            // Stores into byte-celled address spaces.
            (STOREW, false, true) => Ok($execute_impl::<_, _, U8, StoreWOp, true>),
            (STOREW, false, false) => Ok($execute_impl::<_, _, U8, StoreWOp, false>),
            (STOREH, false, true) => Ok($execute_impl::<_, _, U8, StoreHOp, true>),
            (STOREH, false, false) => Ok($execute_impl::<_, _, U8, StoreHOp, false>),
            (STOREB, false, true) => Ok($execute_impl::<_, _, U8, StoreBOp, true>),
            (STOREB, false, false) => Ok($execute_impl::<_, _, U8, StoreBOp, false>),

            // Stores into the native (field-celled) address space.
            (STOREW, true, true) => Ok($execute_impl::<_, _, F, StoreWOp, true>),
            (STOREW, true, false) => Ok($execute_impl::<_, _, F, StoreWOp, false>),
            (STOREH, true, true) => Ok($execute_impl::<_, _, F, StoreHOp, true>),
            (STOREH, true, false) => Ok($execute_impl::<_, _, F, StoreHOp, false>),
            (STOREB, true, true) => Ok($execute_impl::<_, _, F, StoreBOp, true>),
            (STOREB, true, false) => Ok($execute_impl::<_, _, F, StoreBOp, false>),

            (_, _, _) => unreachable!(),
        }
    };
}
111
112impl<F, A, const NUM_CELLS: usize> InterpreterExecutor<F> for LoadStoreExecutor<A, NUM_CELLS>
113where
114    F: PrimeField32,
115{
116    #[inline(always)]
117    fn pre_compute_size(&self) -> usize {
118        size_of::<LoadStorePreCompute>()
119    }
120
121    #[cfg(not(feature = "tco"))]
122    #[inline(always)]
123    fn pre_compute<Ctx: ExecutionCtxTrait>(
124        &self,
125        pc: u32,
126        inst: &Instruction<F>,
127        data: &mut [u8],
128    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError> {
129        let pre_compute: &mut LoadStorePreCompute = data.borrow_mut();
130        let (local_opcode, enabled, is_native_store) =
131            self.pre_compute_impl(pc, inst, pre_compute)?;
132        dispatch!(execute_e1_handler, local_opcode, enabled, is_native_store)
133    }
134
135    #[cfg(feature = "tco")]
136    fn handler<Ctx>(
137        &self,
138        pc: u32,
139        inst: &Instruction<F>,
140        data: &mut [u8],
141    ) -> Result<Handler<F, Ctx>, StaticProgramError>
142    where
143        Ctx: ExecutionCtxTrait,
144    {
145        let pre_compute: &mut LoadStorePreCompute = data.borrow_mut();
146        let (local_opcode, enabled, is_native_store) =
147            self.pre_compute_impl(pc, inst, pre_compute)?;
148        dispatch!(execute_e1_handler, local_opcode, enabled, is_native_store)
149    }
150}
151
152impl<F, A, const NUM_CELLS: usize> InterpreterMeteredExecutor<F> for LoadStoreExecutor<A, NUM_CELLS>
153where
154    F: PrimeField32,
155{
156    fn metered_pre_compute_size(&self) -> usize {
157        size_of::<E2PreCompute<LoadStorePreCompute>>()
158    }
159
160    #[cfg(not(feature = "tco"))]
161    fn metered_pre_compute<Ctx>(
162        &self,
163        chip_idx: usize,
164        pc: u32,
165        inst: &Instruction<F>,
166        data: &mut [u8],
167    ) -> Result<ExecuteFunc<F, Ctx>, StaticProgramError>
168    where
169        Ctx: MeteredExecutionCtxTrait,
170    {
171        let pre_compute: &mut E2PreCompute<LoadStorePreCompute> = data.borrow_mut();
172        pre_compute.chip_idx = chip_idx as u32;
173        let (local_opcode, enabled, is_native_store) =
174            self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
175        dispatch!(execute_e2_handler, local_opcode, enabled, is_native_store)
176    }
177
178    #[cfg(feature = "tco")]
179    fn metered_handler<Ctx>(
180        &self,
181        chip_idx: usize,
182        pc: u32,
183        inst: &Instruction<F>,
184        data: &mut [u8],
185    ) -> Result<Handler<F, Ctx>, StaticProgramError>
186    where
187        Ctx: MeteredExecutionCtxTrait,
188    {
189        let pre_compute: &mut E2PreCompute<LoadStorePreCompute> = data.borrow_mut();
190        pre_compute.chip_idx = chip_idx as u32;
191        let (local_opcode, enabled, is_native_store) =
192            self.pre_compute_impl(pc, inst, &mut pre_compute.data)?;
193        dispatch!(execute_e2_handler, local_opcode, enabled, is_native_store)
194    }
195}
196
197#[inline(always)]
198unsafe fn execute_e12_impl<
199    F: PrimeField32,
200    CTX: ExecutionCtxTrait,
201    T: Copy + Debug + Default,
202    OP: LoadStoreOp<T>,
203    const ENABLED: bool,
204>(
205    pre_compute: &LoadStorePreCompute,
206    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
207) -> Result<(), ExecutionError> {
208    let pc = exec_state.pc();
209    let rs1_bytes: [u8; RV32_REGISTER_NUM_LIMBS] =
210        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32);
211    let rs1_val = u32::from_le_bytes(rs1_bytes);
212    let ptr_val = rs1_val.wrapping_add(pre_compute.imm_extended);
213    // sign_extend([r32{c,g}(b):2]_e)`
214    if ptr_val >= (1 << POINTER_MAX_BITS) {
215        println!(
216            "at {} ptr_val: {ptr_val} >= (1 << POINTER_MAX_BITS): {}",
217            pc,
218            1 << POINTER_MAX_BITS
219        );
220    }
221    debug_assert!(ptr_val < (1 << POINTER_MAX_BITS));
222
223    let shift_amount = ptr_val % 4;
224    let ptr_val = ptr_val - shift_amount; // aligned ptr
225
226    let read_data: [u8; RV32_REGISTER_NUM_LIMBS] = if OP::IS_LOAD {
227        exec_state.vm_read(pre_compute.e as u32, ptr_val)
228    } else {
229        exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32)
230    };
231
232    // We need to write 4 u32s for STORE.
233    let mut write_data: [T; RV32_REGISTER_NUM_LIMBS] = if OP::HOST_READ {
234        exec_state.host_read(pre_compute.e as u32, ptr_val)
235    } else {
236        [T::default(); RV32_REGISTER_NUM_LIMBS]
237    };
238
239    if !OP::compute_write_data(&mut write_data, read_data, shift_amount as usize) {
240        let err = ExecutionError::Fail {
241            pc,
242            msg: "Invalid LoadStoreOp",
243        };
244        return Err(err);
245    }
246
247    if ENABLED {
248        if OP::IS_LOAD {
249            exec_state.vm_write(RV32_REGISTER_AS, pre_compute.a as u32, &write_data);
250        } else {
251            exec_state.vm_write(pre_compute.e as u32, ptr_val, &write_data);
252        }
253    }
254
255    exec_state.set_pc(pc.wrapping_add(DEFAULT_PC_STEP));
256
257    Ok(())
258}
259
260#[create_handler]
261#[inline(always)]
262unsafe fn execute_e1_impl<
263    F: PrimeField32,
264    CTX: ExecutionCtxTrait,
265    T: Copy + Debug + Default,
266    OP: LoadStoreOp<T>,
267    const ENABLED: bool,
268>(
269    pre_compute: *const u8,
270    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
271) -> Result<(), ExecutionError> {
272    let pre_compute: &LoadStorePreCompute =
273        std::slice::from_raw_parts(pre_compute, size_of::<LoadStorePreCompute>()).borrow();
274    execute_e12_impl::<F, CTX, T, OP, ENABLED>(pre_compute, exec_state)
275}
276
277#[create_handler]
278#[inline(always)]
279unsafe fn execute_e2_impl<
280    F: PrimeField32,
281    CTX: MeteredExecutionCtxTrait,
282    T: Copy + Debug + Default,
283    OP: LoadStoreOp<T>,
284    const ENABLED: bool,
285>(
286    pre_compute: *const u8,
287    exec_state: &mut VmExecState<F, GuestMemory, CTX>,
288) -> Result<(), ExecutionError> {
289    let pre_compute: &E2PreCompute<LoadStorePreCompute> =
290        std::slice::from_raw_parts(pre_compute, size_of::<E2PreCompute<LoadStorePreCompute>>())
291            .borrow();
292    exec_state
293        .ctx
294        .on_height_change(pre_compute.chip_idx as usize, 1);
295    execute_e12_impl::<F, CTX, T, OP, ENABLED>(&pre_compute.data, exec_state)
296}
297
/// Compile-time description of one load/store flavour, parameterised by the cell type `T`
/// that is written back.
trait LoadStoreOp<T> {
    /// `true` for loads: `execute_e12_impl` then reads from memory and writes to register
    /// `a`; `false` for stores: it reads register `a` and writes to memory.
    const IS_LOAD: bool;
    /// When `true`, `execute_e12_impl` pre-fills `write_data` with a `host_read` of the
    /// destination word before calling `compute_write_data`; otherwise it starts from
    /// `T::default()` cells.
    const HOST_READ: bool;

    /// Return if the operation is valid.
    ///
    /// `read_data` holds the bytes that were read, `shift_amount` is the pointer's offset
    /// within its aligned word (`ptr % 4`), and `write_data` receives the cells to write.
    /// A `false` return makes the caller fail with "Invalid LoadStoreOp".
    fn compute_write_data(
        write_data: &mut [T; RV32_REGISTER_NUM_LIMBS],
        read_data: [u8; RV32_REGISTER_NUM_LIMBS],
        shift_amount: usize,
    ) -> bool;
}
/// Wrapper type for `u8` so we can implement `LoadStoreOp<F>` for `F: PrimeField32`
/// alongside the byte-cell implementations without the two overlapping.
/// For memory reads/writes, this type behaves the same as `u8`.
// NOTE(review): `dead_code` is presumably allowed because the wrapped byte is only ever
// constructed and handed to memory writes, never read through the field — confirm.
#[allow(dead_code)]
#[derive(Copy, Clone, Debug, Default)]
struct U8(u8);
/// Marker for `LOADW` (see `dispatch!`).
struct LoadWOp;
/// Marker for `LOADHU` (see `dispatch!`).
struct LoadHUOp;
/// Marker for `LOADBU` (see `dispatch!`).
struct LoadBUOp;
/// Marker for `STOREW` (see `dispatch!`).
struct StoreWOp;
/// Marker for `STOREH` (see `dispatch!`).
struct StoreHOp;
/// Marker for `STOREB` (see `dispatch!`).
struct StoreBOp;
320impl LoadStoreOp<U8> for LoadWOp {
321    const IS_LOAD: bool = true;
322    const HOST_READ: bool = false;
323
324    #[inline(always)]
325    fn compute_write_data(
326        write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS],
327        read_data: [u8; RV32_REGISTER_NUM_LIMBS],
328        _shift_amount: usize,
329    ) -> bool {
330        *write_data = read_data.map(U8);
331        true
332    }
333}
334
335impl LoadStoreOp<U8> for LoadHUOp {
336    const IS_LOAD: bool = true;
337    const HOST_READ: bool = false;
338    #[inline(always)]
339    fn compute_write_data(
340        write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS],
341        read_data: [u8; RV32_REGISTER_NUM_LIMBS],
342        shift_amount: usize,
343    ) -> bool {
344        if shift_amount != 0 && shift_amount != 2 {
345            return false;
346        }
347        write_data[0] = U8(read_data[shift_amount]);
348        write_data[1] = U8(read_data[shift_amount + 1]);
349        true
350    }
351}
352impl LoadStoreOp<U8> for LoadBUOp {
353    const IS_LOAD: bool = true;
354    const HOST_READ: bool = false;
355    #[inline(always)]
356    fn compute_write_data(
357        write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS],
358        read_data: [u8; RV32_REGISTER_NUM_LIMBS],
359        shift_amount: usize,
360    ) -> bool {
361        write_data[0] = U8(read_data[shift_amount]);
362        true
363    }
364}
365
366impl LoadStoreOp<U8> for StoreWOp {
367    const IS_LOAD: bool = false;
368    const HOST_READ: bool = false;
369    #[inline(always)]
370    fn compute_write_data(
371        write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS],
372        read_data: [u8; RV32_REGISTER_NUM_LIMBS],
373        _shift_amount: usize,
374    ) -> bool {
375        *write_data = read_data.map(U8);
376        true
377    }
378}
379impl LoadStoreOp<U8> for StoreHOp {
380    const IS_LOAD: bool = false;
381    const HOST_READ: bool = true;
382
383    #[inline(always)]
384    fn compute_write_data(
385        write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS],
386        read_data: [u8; RV32_REGISTER_NUM_LIMBS],
387        shift_amount: usize,
388    ) -> bool {
389        if shift_amount != 0 && shift_amount != 2 {
390            return false;
391        }
392        write_data[shift_amount] = U8(read_data[0]);
393        write_data[shift_amount + 1] = U8(read_data[1]);
394        true
395    }
396}
397impl LoadStoreOp<U8> for StoreBOp {
398    const IS_LOAD: bool = false;
399    const HOST_READ: bool = true;
400    #[inline(always)]
401    fn compute_write_data(
402        write_data: &mut [U8; RV32_REGISTER_NUM_LIMBS],
403        read_data: [u8; RV32_REGISTER_NUM_LIMBS],
404        shift_amount: usize,
405    ) -> bool {
406        write_data[shift_amount] = U8(read_data[0]);
407        true
408    }
409}
410
411impl<F: PrimeField32> LoadStoreOp<F> for StoreWOp {
412    const IS_LOAD: bool = false;
413    const HOST_READ: bool = false;
414    #[inline(always)]
415    fn compute_write_data(
416        write_data: &mut [F; RV32_REGISTER_NUM_LIMBS],
417        read_data: [u8; RV32_REGISTER_NUM_LIMBS],
418        _shift_amount: usize,
419    ) -> bool {
420        *write_data = read_data.map(F::from_canonical_u8);
421        true
422    }
423}
424impl<F: PrimeField32> LoadStoreOp<F> for StoreHOp {
425    const IS_LOAD: bool = false;
426    const HOST_READ: bool = true;
427
428    #[inline(always)]
429    fn compute_write_data(
430        write_data: &mut [F; RV32_REGISTER_NUM_LIMBS],
431        read_data: [u8; RV32_REGISTER_NUM_LIMBS],
432        shift_amount: usize,
433    ) -> bool {
434        if shift_amount != 0 && shift_amount != 2 {
435            return false;
436        }
437        write_data[shift_amount] = F::from_canonical_u8(read_data[0]);
438        write_data[shift_amount + 1] = F::from_canonical_u8(read_data[1]);
439        true
440    }
441}
442impl<F: PrimeField32> LoadStoreOp<F> for StoreBOp {
443    const IS_LOAD: bool = false;
444    const HOST_READ: bool = true;
445    #[inline(always)]
446    fn compute_write_data(
447        write_data: &mut [F; RV32_REGISTER_NUM_LIMBS],
448        read_data: [u8; RV32_REGISTER_NUM_LIMBS],
449        shift_amount: usize,
450    ) -> bool {
451        write_data[shift_amount] = F::from_canonical_u8(read_data[0]);
452        true
453    }
454}