openvm_rv32im_circuit/common/mod.rs
1#[cfg(feature = "aot")]
2pub(crate) use aot::*;
3
4#[cfg(feature = "aot")]
5mod aot {
6 use std::mem::offset_of;
7
8 pub(crate) use openvm_circuit::arch::aot::common::*;
9 use openvm_circuit::{
10 arch::{
11 execution_mode::{metered::memory_ctx::MemoryCtx, MeteredCtx},
12 AotError, SystemConfig, VmExecState, ADDR_SPACE_OFFSET,
13 },
14 system::memory::{merkle::public_values::PUBLIC_VALUES_AS, online::GuestMemory, CHUNK},
15 };
16 use openvm_instructions::riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS};
17
    /// Minimum block size is 4 bytes, stored here as a bit count (`4 == 1 << 2`). RISC-V `lb`
    /// only requires alignment of 1 and `lh` only requires alignment of 2, yet both still work
    /// because those instructions are implemented as accesses of block size 4.
21 const DEFAULT_U8_BLOCK_SIZE_BITS: u8 = 2;
22 /// This is DIRTY because PAGE_BITS is a generic parameter of E2 context.
23 const DEFAULT_PAGE_BITS: usize = 6;
24
25 pub(crate) fn gpr_to_rv32_register(gpr: &str, rv32_reg: u8) -> String {
26 let xmm_map_reg = rv32_reg / 2;
27 if rv32_reg % 2 == 0 {
28 format!(" pinsrd xmm{xmm_map_reg}, {gpr}, 0\n")
29 } else {
30 format!(" pinsrd xmm{xmm_map_reg}, {gpr}, 1\n")
31 }
32 }
33
34 pub(crate) fn address_space_start_to_gpr(address_space: u32, gpr: &str) -> String {
35 if address_space == 2 {
36 if REG_AS2_PTR != gpr {
37 return format!(" mov {gpr}, r15\n");
38 }
39 return "".to_string();
40 }
41
42 let xmm_map_reg = match address_space {
43 1 => "xmm0",
44 3 => "xmm1",
45 4 => "xmm2",
46 _ => unreachable!("Only address space 1, 2, 3, 4 is supported"),
47 };
48 format!(" pextrq {gpr}, {xmm_map_reg}, 1\n")
49 }
50
51 /*
52 input:
53 - riscv register number
54 - gpr register to write into
55 - is_gpr_force_write boolean
56
57 output:
58 - string representing the general purpose register that stores the value of register number `rv32_reg`
59 - emitted assembly string that performs the move
60 */
61 pub(crate) fn xmm_to_gpr(
62 rv32_reg: u8,
63 gpr: &str,
64 is_gpr_force_write: bool,
65 ) -> (String, String) {
66 if let Some(override_reg) = RISCV_TO_X86_OVERRIDE_MAP[rv32_reg as usize] {
67 // a/4 is overridden, b/4 is overridden
68 if is_gpr_force_write {
69 return (gpr.to_string(), format!(" mov {gpr}, {override_reg}\n"));
70 }
71 return (override_reg.to_string(), "".to_string());
72 }
73 let xmm_map_reg = rv32_reg / 2;
74 if rv32_reg % 2 == 0 {
75 (
76 gpr.to_string(),
77 format!(" pextrd {gpr}, xmm{xmm_map_reg}, 0\n"),
78 )
79 } else {
80 (
81 gpr.to_string(),
82 format!(" pextrd {gpr}, xmm{xmm_map_reg}, 1\n"),
83 )
84 }
85 }
86
87 pub(crate) fn gpr_to_xmm(gpr: &str, rv32_reg: u8) -> String {
88 if let Some(override_reg) = RISCV_TO_X86_OVERRIDE_MAP[rv32_reg as usize] {
89 if gpr == override_reg {
90 //already in correct location
91 return "".to_string();
92 }
93 return format!(" mov {override_reg}, {gpr}\n");
94 }
95 let xmm_map_reg = rv32_reg / 2;
96 if rv32_reg % 2 == 0 {
97 format!(" pinsrd xmm{xmm_map_reg}, {gpr}, 0\n")
98 } else {
99 format!(" pinsrd xmm{xmm_map_reg}, {gpr}, 1\n")
100 }
101 }
102 pub(crate) fn update_adapter_heights_asm(
103 config: &SystemConfig,
104 _address_space: u32,
105 ) -> Result<String, AotError> {
106 let min_block_size_bits = config.memory_config.min_block_size_bits();
107 if min_block_size_bits[RV32_REGISTER_AS as usize] != DEFAULT_U8_BLOCK_SIZE_BITS {
108 println!("RV32_REGISTER_AS must have a minimum block size of 4");
109 return Err(AotError::Other(String::from(
110 "RV32_REGISTER_AS must have a minimum block size of 4",
111 )));
112 }
113 if min_block_size_bits[RV32_MEMORY_AS as usize] != DEFAULT_U8_BLOCK_SIZE_BITS {
114 println!("RV32_MEMORY_AS must have a minimum block size of 4");
115 return Err(AotError::Other(String::from(
116 "RV32_MEMORY_AS must have a minimum block size of 4",
117 )));
118 }
119 if min_block_size_bits[PUBLIC_VALUES_AS as usize] != DEFAULT_U8_BLOCK_SIZE_BITS {
120 println!("PUBLIC_VALUES_AS must have a minimum block size of 4");
121 return Err(AotError::Other(String::from(
122 "PUBLIC_VALUES_AS must have a minimum block size of 4",
123 )));
124 }
125
126 // `update_adapter_heights_asm` rewrites the following code in ASM for
127 // `on_memory_operation`: ```
128 // pub fn update_adapter_heights_batch(
129 // &self,
130 // trace_heights: &mut [u32],
131 // address_space: u32,
132 // size_bits: u32,
133 // num: u32,
134 // ) {
135 // let align_bits = unsafe {
136 // *self
137 // .min_block_size_bits
138 // .get_unchecked(address_space as usize)
139 // };
140 //
141 // for adapter_bits in (align_bits as u32 + 1..=size_bits).rev() {
142 // let adapter_idx = self.adapter_offset + adapter_bits as usize - 1;
143 // debug_assert!(adapter_idx < trace_heights.len());
144 // unsafe {
145 // *trace_heights.get_unchecked_mut(adapter_idx) +=
146 // num << (size_bits - adapter_bits + 1);
147 // }
148 // }
149 // }
150 // ```
151 //
152 // For a specific RV32 instruction, the variables can be treated as constants at AOT
153 // compilation time:
154 // - `address_space`: always a constant because it is derived from an Instruction
155 // - `num`: always 1 in `on_memory_operation`
156 // - `align_bits`: always a constant because `address_space` is a constant
157 // - `size_bits`: RV32 instruction always read 4 bytes(in the AIR level). So `size` is
158 // always 4 bytes. So `size_bits` is always 2.
159 //
160 // If we ignore the Native address space, `min_block_size_bits`` is always
161 // `DEFAULT_U8_BLOCK_SIZE=4`. Therefore, `align_bits` is always 2. So the loop will
162 // never be executed and we can leave the function empty.
163 Ok("".to_string())
164 }
165
    /// Generate ASM code for updating the boundary merkle heights.
    ///
    /// # Arguments
    ///
    /// * `config` - The system configuration.
    /// * `address_space` - The address space. The code computes
    ///   `address_space - ADDR_SPACE_OFFSET`, so smaller values would underflow.
    /// * `pc` - The program counter of the current instruction. Only used to build a unique
    ///   local label so the emitted branch does not collide with other instructions' snippets.
    /// * `ptr_reg` - The register to store the accessed pointer. The caller should not expect the
    ///   value of this register to be preserved.
    /// * `reg1` - A register to store the intermediate result.
    /// * `reg2` - A register to store the intermediate result.
    ///
    /// # Returns
    ///
    /// The ASM code for updating the boundary merkle heights.
    pub(crate) fn update_boundary_merkle_heights_asm<F>(
        config: &SystemConfig,
        address_space: u32,
        pc: u32,
        ptr_reg: &str,
        reg1: &str,
        reg2: &str,
    ) -> Result<String, AotError> {
        // `update_boundary_merkle_heights_asm` rewrites the following code in ASM for
        // `on_memory_operation`:
        // ```
        // pub fn label_to_index((addr_space, block_id): (u32, u32)) -> u64 {
        //     (((addr_space - ADDR_SPACE_OFFSET) as u64) << self.address_height) + block_id as u64
        // }
        //
        // pub(crate) fn update_boundary_merkle_heights(
        //     &mut self,
        //     address_space: u32,
        //     ptr: u32,
        //     size: u32,
        // ) {
        //     let num_blocks = (size + self.chunk - 1) >> self.chunk_bits;
        //     let start_chunk_id = ptr >> self.chunk_bits;
        //     let start_block_id = if self.chunk == 1 {
        //         start_chunk_id
        //     } else {
        //         self.memory_dimensions
        //             .label_to_index((address_space, start_chunk_id)) as u32
        //     };
        //     // Because `self.chunk == 1 << self.chunk_bits`
        //     let end_block_id = start_block_id + num_blocks;
        //     let start_page_id = start_block_id >> PAGE_BITS;
        //     let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1;
        //
        //     for page_id in start_page_id..end_page_id {
        //         // Append page_id to page_indices_since_checkpoint
        //         let len = self.page_indices_since_checkpoint_len;
        //         // SAFETY: len is within bounds, and we extend length by 1 after writing.
        //         unsafe {
        //             *self.page_indices_since_checkpoint.as_mut_ptr().add(len) = page_id;
        //         }
        //         self.page_indices_since_checkpoint_len = len + 1;
        //
        //         if self.page_indices.insert(page_id as usize) {
        //             // SAFETY: address_space passed is usually a hardcoded constant or
        //             // derived from an Instruction where it is bounds checked before passing
        //             unsafe {
        //                 *self
        //                     .addr_space_access_count
        //                     .get_unchecked_mut(address_space as usize) += 1;
        //             }
        //         }
        //     }
        // }
        // ```
        //
        // For a specific RV32 instruction, the variables can be treated as constants at AOT
        // compilation time.
        // Inputs:
        // - `chunk`: always 8 (CHUNK) because we only support when continuation is enabled.
        // - `address_space`: always a constant because it is derived from an Instruction.
        // - `size`: RV32 instruction always reads 4 bytes (at the AIR level).
        // - `self.memory_dimensions.address_height`: known at AOT compilation time because it
        //   is derived from the memory configuration.
        // Inside the function body:
        // - `num_blocks`: `(size + self.chunk - 1) >> self.chunk_bits = (4 + 8 - 1) >> 3 = 1`
        // - `as_offset = ((addr_space - ADDR_SPACE_OFFSET) as u64) << self.address_height`:
        //   constant because `address_space` and `address_height` are constants
        // - `start_chunk_id`: `ptr >> self.chunk_bits`
        // - `start_block_id`: `start_chunk_id + as_offset`
        // - `end_block_id`: `start_block_id + num_blocks = start_block_id + 1`
        // - `start_page_id`: `start_block_id >> PAGE_BITS`
        // - `end_page_id`: `((end_block_id - 1) >> PAGE_BITS) + 1 = (start_block_id >> PAGE_BITS) + 1`
        //
        // Therefore the loop only iterates once for `page_id = start_page_id`.

        // The emitted code hardcodes CHUNK-sized blocks; reject any other configuration.
        let initial_block_size: usize = config.initial_block_size();
        if initial_block_size != CHUNK {
            return Err(AotError::Other(format!(
                "initial_block_size must be {CHUNK}, got {initial_block_size}"
            )));
        }
        let chunk_bits = CHUNK.ilog2();
        // Constant part of `label_to_index`, folded at AOT compilation time.
        let as_offset = ((address_space - ADDR_SPACE_OFFSET) as u64)
            << (config.memory_config.memory_dimensions().address_height);

        // On entry `ptr_reg` holds the accessed pointer; the next three instructions turn it
        // into `start_page_id` in place.
        let mut asm_str = String::new();
        // `start_chunk_id`: `ptr >> self.chunk_bits`
        asm_str += &format!(" shr {ptr_reg}, {chunk_bits}\n");
        // `start_block_id`: `start_chunk_id + as_offset`
        asm_str += &format!(" add {ptr_reg}, {as_offset}\n");
        // `start_page_id`: `start_block_id >> PAGE_BITS`
        // NOTE: This is DIRTY because PAGE_BITS is a generic parameter of E2 context.
        asm_str += &format!(" shr {ptr_reg}, {DEFAULT_PAGE_BITS}\n");

        // Byte offsets of the `MemoryCtx` fields relative to the `VmExecState` pointer held in
        // `REG_EXEC_STATE_PTR`, all computed at AOT compilation time via `offset_of!`.
        let memory_ctx_offset = offset_of!(VmExecState<F, GuestMemory, MeteredCtx>, ctx)
            + offset_of!(MeteredCtx, memory_ctx);
        let page_indices_ptr_offset =
            memory_ctx_offset + offset_of!(MemoryCtx<DEFAULT_PAGE_BITS>, page_indices);
        let addr_space_access_count_ptr_offset =
            memory_ctx_offset + offset_of!(MemoryCtx<DEFAULT_PAGE_BITS>, addr_space_access_count);
        let page_indices_since_checkpoint_ptr_offset = memory_ctx_offset
            + offset_of!(MemoryCtx<DEFAULT_PAGE_BITS>, page_indices_since_checkpoint);
        let page_indices_since_checkpoint_len_offset = memory_ctx_offset
            + offset_of!(
                MemoryCtx<DEFAULT_PAGE_BITS>,
                page_indices_since_checkpoint_len
            );
        // Per-instruction label: skip target when the page bit was already set.
        let inserted_label = format!(".asm_execute_pc_{pc}_inserted");

        // Append page_id to page_indices_since_checkpoint.
        // `reg1 = len`, `reg2 = page_indices_since_checkpoint` base pointer.
        asm_str += &format!(
            " mov {reg1}, [{REG_EXEC_STATE_PTR} + {page_indices_since_checkpoint_len_offset}]\n"
        );
        asm_str += &format!(
            " mov {reg2}, [{REG_EXEC_STATE_PTR} + {page_indices_since_checkpoint_ptr_offset}]\n"
        );
        // The buffer holds 32-bit page ids (dword store, scale 4), so the store must use the
        // 32-bit alias of `ptr_reg`.
        let ptr_reg_32 = convert_x86_reg(ptr_reg, Width::W32).ok_or_else(|| {
            AotError::Other(format!("unsupported ptr_reg for 32-bit store: {ptr_reg}"))
        })?;
        asm_str += &format!(" mov dword ptr [{reg2} + {reg1} * 4], {ptr_reg_32}\n");
        // `self.page_indices_since_checkpoint_len = len + 1`
        asm_str += &format!(" add {reg1}, 1\n");
        asm_str += &format!(
            " mov [{REG_EXEC_STATE_PTR} + {page_indices_since_checkpoint_len_offset}], {reg1}\n"
        );

        // The next section is the implementation of `BitSet::insert` in ASM.
        // pub fn insert(&mut self, index: usize) -> bool {
        //     let word_index = index >> 6;
        //     let bit_index = index & 63;
        //     let mask = 1u64 << bit_index;
        //     let word = unsafe { self.words.get_unchecked_mut(word_index) };
        //     let was_set = (*word & mask) != 0;
        //     *word |= mask;
        //     !was_set
        // }

        // Start with `ptr_reg = index`
        // `reg1 = word_index`
        asm_str += &format!(" mov {reg1}, {ptr_reg}\n");
        asm_str += &format!(" shr {reg1}, 6\n");
        // `ptr_reg = bit_index = index & 63`
        asm_str += &format!(" and {ptr_reg}, 63\n");
        // `reg2 = mask = 1u64 << bit_index`
        asm_str += &format!(" mov {reg2}, 1\n");
        asm_str += &format!(" shlx {reg2}, {reg2}, {ptr_reg}\n");
        // `ptr_reg = self.page_indices.ptr`
        asm_str +=
            &format!(" mov {ptr_reg}, [{REG_EXEC_STATE_PTR} + {page_indices_ptr_offset}]\n");

        // `reg1 = word_ptr = &self.words.get_unchecked_mut(word_index)` (scale 8: u64 words)
        asm_str += &format!(" lea {reg1}, [{ptr_reg} + {reg1} * 8]\n");
        // `ptr_reg = word = *word_ptr`
        asm_str += &format!(" mov {ptr_reg}, [{reg1}]\n");

        // `test (*word & mask)`: if the bit is already set the page was counted before, so
        // jump past the access-count update.
        asm_str += &format!(" test {ptr_reg}, {reg2}\n");
        asm_str += &format!(" jnz {inserted_label}\n");
        // When (*word & mask) == 0
        // `*word += mask` — equivalent to `*word |= mask` because the bit is known clear here.
        asm_str += &format!(" add {ptr_reg}, {reg2}\n");
        asm_str += &format!(" mov [{reg1}], {ptr_reg}\n");
        // reg1 = &addr_space_access_count.as_ptr()
        asm_str += &format!(
            " lea {reg1}, [{REG_EXEC_STATE_PTR} + {addr_space_access_count_ptr_offset}]\n"
        );
        asm_str += &format!(" mov {reg1}, [{reg1}]\n");
        // self.addr_space_access_count[address_space] += 1;
        // (`address_space` is an AOT-time constant, so the scaled displacement is fixed.)
        asm_str += &format!(" add dword ptr [{reg1} + {address_space} * 4], 1\n");
        asm_str += &format!("{inserted_label}:\n");
        // Inserted, do nothing

        Ok(asm_str)
    }
351
352 /// Assumption: `REG_TRACE_HEIGHT` is the pointer of `trace_heights``.
353 pub(crate) fn update_height_change_asm(
354 chip_idx: usize,
355 height_delta: u32,
356 ) -> Result<String, AotError> {
357 let mut asm_str = String::new();
358 // `update_height_change_asm` rewrites the following code in ASM for `on_height_change`:
359 // ```
360 // pub fn on_height_change(&mut self, chip_idx: usize, height_delta: u32) {
361 // self.trace_heights[chip_idx] += height_delta;
362 // }
363 // ```
364 asm_str +=
365 &format!(" add dword ptr [{REG_TRACE_HEIGHT} + {chip_idx} * 4], {height_delta}\n");
366 Ok(asm_str)
367 }
368}