openvm_rv32im_circuit/adapters/
mod.rs

1use std::ops::Mul;
2
3use openvm_circuit::{
4    arch::{execution_mode::ExecutionCtxTrait, VmStateMut},
5    system::memory::{
6        merkle::public_values::PUBLIC_VALUES_AS,
7        online::{GuestMemory, TracingMemory},
8    },
9};
10use openvm_instructions::riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS};
11use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32};
12
13mod alu;
14mod branch;
15mod jalr;
16mod loadstore;
17mod mul;
18mod rdwrite;
19
20pub use alu::*;
21pub use branch::*;
22pub use jalr::*;
23pub use loadstore::*;
24pub use mul::*;
25pub use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
26pub use rdwrite::*;
27
/// Number of limbs in a 256-bit heap integer: it is stored as 32 bytes, i.e. 32 limbs of 8 bits
/// each.
pub const INT256_NUM_LIMBS: usize = 32;

/// Bit width of the immediate for "IS"-type instructions (presumably RISC-V I-type and S-type —
/// confirm).
///
/// For soundness, this should be <= 16.
pub const RV_IS_TYPE_IMM_BITS: usize = 12;

/// Bit width of the B-type (branch) immediate.
///
/// The branch immediate value is in [-2^12, 2^12).
pub const RV_B_TYPE_IMM_BITS: usize = 13;

/// Bit width of the J-type immediate.
///
/// NOTE(review): by analogy with the branch immediate this presumably corresponds to a value in
/// [-2^20, 2^20) — confirm.
pub const RV_J_TYPE_IMM_BITS: usize = 21;
38
39/// Convert the RISC-V register data (32 bits represented as 4 bytes, where each byte is represented
40/// as a field element) back into its value as u32.
41pub fn compose<F: PrimeField32>(ptr_data: [F; RV32_REGISTER_NUM_LIMBS]) -> u32 {
42    let mut val = 0;
43    for (i, limb) in ptr_data.map(|x| x.as_canonical_u32()).iter().enumerate() {
44        val += limb << (i * 8);
45    }
46    val
47}
48
49/// inverse of `compose`
50pub fn decompose<F: PrimeField32>(value: u32) -> [F; RV32_REGISTER_NUM_LIMBS] {
51    std::array::from_fn(|i| {
52        F::from_canonical_u32((value >> (RV32_CELL_BITS * i)) & ((1 << RV32_CELL_BITS) - 1))
53    })
54}
55
56#[inline(always)]
57pub fn imm_to_bytes(imm: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] {
58    debug_assert_eq!(imm >> 24, 0);
59    let mut imm_le = imm.to_le_bytes();
60    imm_le[3] = imm_le[2];
61    imm_le
62}
63
64#[inline(always)]
65pub fn memory_read<const N: usize>(memory: &GuestMemory, address_space: u32, ptr: u32) -> [u8; N] {
66    debug_assert!(
67        address_space == RV32_REGISTER_AS
68            || address_space == RV32_MEMORY_AS
69            || address_space == PUBLIC_VALUES_AS,
70    );
71
72    // SAFETY:
73    // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and
74    //   minimum alignment of `RV32_REGISTER_NUM_LIMBS`
75    unsafe { memory.read::<u8, N>(address_space, ptr) }
76}
77
78#[inline(always)]
79pub fn memory_write<const N: usize>(
80    memory: &mut GuestMemory,
81    address_space: u32,
82    ptr: u32,
83    data: [u8; N],
84) {
85    debug_assert!(
86        address_space == RV32_REGISTER_AS
87            || address_space == RV32_MEMORY_AS
88            || address_space == PUBLIC_VALUES_AS
89    );
90
91    // SAFETY:
92    // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and
93    //   minimum alignment of `RV32_REGISTER_NUM_LIMBS`
94    unsafe { memory.write::<u8, N>(address_space, ptr, data) }
95}
96
97/// Atomic read operation which increments the timestamp by 1.
98/// Returns `(t_prev, [ptr:4]_{address_space})` where `t_prev` is the timestamp of the last memory
99/// access.
100#[inline(always)]
101pub fn timed_read<const N: usize>(
102    memory: &mut TracingMemory,
103    address_space: u32,
104    ptr: u32,
105) -> (u32, [u8; N]) {
106    debug_assert!(
107        address_space == RV32_REGISTER_AS
108            || address_space == RV32_MEMORY_AS
109            || address_space == PUBLIC_VALUES_AS
110    );
111
112    // SAFETY:
113    // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and
114    //   minimum alignment of `RV32_REGISTER_NUM_LIMBS`
115    #[cfg(feature = "legacy-v1-3-mem-align")]
116    if address_space == RV32_MEMORY_AS {
117        unsafe { memory.read::<u8, N, 1>(address_space, ptr) }
118    } else {
119        unsafe { memory.read::<u8, N, RV32_REGISTER_NUM_LIMBS>(address_space, ptr) }
120    }
121    #[cfg(not(feature = "legacy-v1-3-mem-align"))]
122    unsafe {
123        memory.read::<u8, N, RV32_REGISTER_NUM_LIMBS>(address_space, ptr)
124    }
125}
126
127#[inline(always)]
128pub fn timed_write<const N: usize>(
129    memory: &mut TracingMemory,
130    address_space: u32,
131    ptr: u32,
132    data: [u8; N],
133) -> (u32, [u8; N]) {
134    debug_assert!(
135        address_space == RV32_REGISTER_AS
136            || address_space == RV32_MEMORY_AS
137            || address_space == PUBLIC_VALUES_AS
138    );
139
140    // SAFETY:
141    // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and
142    //   minimum alignment of `RV32_REGISTER_NUM_LIMBS`
143    #[cfg(feature = "legacy-v1-3-mem-align")]
144    if address_space == RV32_MEMORY_AS {
145        unsafe { memory.write::<u8, N, 1>(address_space, ptr, data) }
146    } else {
147        unsafe { memory.write::<u8, N, RV32_REGISTER_NUM_LIMBS>(address_space, ptr, data) }
148    }
149    #[cfg(not(feature = "legacy-v1-3-mem-align"))]
150    unsafe {
151        memory.write::<u8, N, RV32_REGISTER_NUM_LIMBS>(address_space, ptr, data)
152    }
153}
154
155/// Reads register value at `reg_ptr` from memory and records the memory access in mutable buffer.
156/// Trace generation relevant to this memory access can be done fully from the recorded buffer.
157#[inline(always)]
158pub fn tracing_read<const N: usize>(
159    memory: &mut TracingMemory,
160    address_space: u32,
161    ptr: u32,
162    prev_timestamp: &mut u32,
163) -> [u8; N] {
164    let (t_prev, data) = timed_read(memory, address_space, ptr);
165    *prev_timestamp = t_prev;
166    data
167}
168
169#[inline(always)]
170pub fn tracing_read_imm(
171    memory: &mut TracingMemory,
172    imm: u32,
173    imm_mut: &mut u32,
174) -> [u8; RV32_REGISTER_NUM_LIMBS] {
175    *imm_mut = imm;
176    debug_assert_eq!(imm >> 24, 0); // highest byte should be zero to prevent overflow
177
178    memory.increment_timestamp();
179
180    let mut imm_le = imm.to_le_bytes();
181    // Important: we set the highest byte equal to the second highest byte, using the assumption
182    // that imm is at most 24 bits
183    imm_le[3] = imm_le[2];
184    imm_le
185}
186
187/// Writes `reg_ptr, reg_val` into memory and records the memory access in mutable buffer.
188/// Trace generation relevant to this memory access can be done fully from the recorded buffer.
189#[inline(always)]
190pub fn tracing_write<const N: usize>(
191    memory: &mut TracingMemory,
192    address_space: u32,
193    ptr: u32,
194    data: [u8; N],
195    prev_timestamp: &mut u32,
196    prev_data: &mut [u8; N],
197) {
198    let (t_prev, data_prev) = timed_write(memory, address_space, ptr, data);
199    *prev_timestamp = t_prev;
200    *prev_data = data_prev;
201}
202
203#[inline(always)]
204pub fn memory_read_from_state<F, Ctx, const N: usize>(
205    state: &mut VmStateMut<F, GuestMemory, Ctx>,
206    address_space: u32,
207    ptr: u32,
208) -> [u8; N]
209where
210    Ctx: ExecutionCtxTrait,
211{
212    state.ctx.on_memory_operation(address_space, ptr, N as u32);
213
214    memory_read(state.memory, address_space, ptr)
215}
216
217#[inline(always)]
218pub fn memory_write_from_state<F, Ctx, const N: usize>(
219    state: &mut VmStateMut<F, GuestMemory, Ctx>,
220    address_space: u32,
221    ptr: u32,
222    data: [u8; N],
223) where
224    Ctx: ExecutionCtxTrait,
225{
226    state.ctx.on_memory_operation(address_space, ptr, N as u32);
227
228    memory_write(state.memory, address_space, ptr, data)
229}
230
231#[inline(always)]
232pub fn read_rv32_register_from_state<F, Ctx>(
233    state: &mut VmStateMut<F, GuestMemory, Ctx>,
234    ptr: u32,
235) -> u32
236where
237    Ctx: ExecutionCtxTrait,
238{
239    u32::from_le_bytes(memory_read_from_state(state, RV32_REGISTER_AS, ptr))
240}
241
242#[inline(always)]
243pub fn read_rv32_register(memory: &GuestMemory, ptr: u32) -> u32 {
244    u32::from_le_bytes(memory_read(memory, RV32_REGISTER_AS, ptr))
245}
246
247pub fn abstract_compose<T: FieldAlgebra, V: Mul<T, Output = T>>(
248    data: [V; RV32_REGISTER_NUM_LIMBS],
249) -> T {
250    data.into_iter()
251        .enumerate()
252        .fold(T::ZERO, |acc, (i, limb)| {
253            acc + limb * T::from_canonical_u32(1 << (i * RV32_CELL_BITS))
254        })
255}
256
257// TEMP[jpw]
258pub fn tmp_convert_to_u8s<F: PrimeField32, const N: usize>(data: [F; N]) -> [u8; N] {
259    data.map(|x| x.as_canonical_u32() as u8)
260}