// openvm_rv32im_circuit/common/mod.rs

1#[cfg(feature = "aot")]
2pub(crate) use aot::*;
3
4#[cfg(feature = "aot")]
5mod aot {
6    use std::mem::offset_of;
7
8    pub(crate) use openvm_circuit::arch::aot::common::*;
9    use openvm_circuit::{
10        arch::{
11            execution_mode::{metered::memory_ctx::MemoryCtx, MeteredCtx},
12            AotError, SystemConfig, VmExecState, ADDR_SPACE_OFFSET,
13        },
14        system::memory::{merkle::public_values::PUBLIC_VALUES_AS, online::GuestMemory, CHUNK},
15    };
16    use openvm_instructions::riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS};
17
18    /// The minimum block size is 4, but RISC-V `lb` only requires alignment of 1 and `lh` only
19    /// requires alignment of 2 because the instructions are implemented by doing an access of
20    /// block size 4.
21    const DEFAULT_U8_BLOCK_SIZE_BITS: u8 = 2;
22    /// This is DIRTY because PAGE_BITS is a generic parameter of E2 context.
23    const DEFAULT_PAGE_BITS: usize = 6;
24
25    pub(crate) fn gpr_to_rv32_register(gpr: &str, rv32_reg: u8) -> String {
26        let xmm_map_reg = rv32_reg / 2;
27        if rv32_reg % 2 == 0 {
28            format!("   pinsrd xmm{xmm_map_reg}, {gpr}, 0\n")
29        } else {
30            format!("   pinsrd xmm{xmm_map_reg}, {gpr}, 1\n")
31        }
32    }
33
34    pub(crate) fn address_space_start_to_gpr(address_space: u32, gpr: &str) -> String {
35        if address_space == 2 {
36            if REG_AS2_PTR != gpr {
37                return format!("    mov {gpr}, r15\n");
38            }
39            return "".to_string();
40        }
41
42        let xmm_map_reg = match address_space {
43            1 => "xmm0",
44            3 => "xmm1",
45            4 => "xmm2",
46            _ => unreachable!("Only address space 1, 2, 3, 4 is supported"),
47        };
48        format!("   pextrq {gpr}, {xmm_map_reg}, 1\n")
49    }
50
51    /*
52    input:
53    - riscv register number
54    - gpr register to write into
55    - is_gpr_force_write boolean
56
57    output:
58    - string representing the general purpose register that stores the value of register number `rv32_reg`
59    - emitted assembly string that performs the move
60    */
61    pub(crate) fn xmm_to_gpr(
62        rv32_reg: u8,
63        gpr: &str,
64        is_gpr_force_write: bool,
65    ) -> (String, String) {
66        if let Some(override_reg) = RISCV_TO_X86_OVERRIDE_MAP[rv32_reg as usize] {
67            // a/4 is overridden, b/4 is overridden
68            if is_gpr_force_write {
69                return (gpr.to_string(), format!("  mov {gpr}, {override_reg}\n"));
70            }
71            return (override_reg.to_string(), "".to_string());
72        }
73        let xmm_map_reg = rv32_reg / 2;
74        if rv32_reg % 2 == 0 {
75            (
76                gpr.to_string(),
77                format!("   pextrd {gpr}, xmm{xmm_map_reg}, 0\n"),
78            )
79        } else {
80            (
81                gpr.to_string(),
82                format!("   pextrd {gpr}, xmm{xmm_map_reg}, 1\n"),
83            )
84        }
85    }
86
87    pub(crate) fn gpr_to_xmm(gpr: &str, rv32_reg: u8) -> String {
88        if let Some(override_reg) = RISCV_TO_X86_OVERRIDE_MAP[rv32_reg as usize] {
89            if gpr == override_reg {
90                //already in correct location
91                return "".to_string();
92            }
93            return format!("   mov {override_reg}, {gpr}\n");
94        }
95        let xmm_map_reg = rv32_reg / 2;
96        if rv32_reg % 2 == 0 {
97            format!("   pinsrd xmm{xmm_map_reg}, {gpr}, 0\n")
98        } else {
99            format!("   pinsrd xmm{xmm_map_reg}, {gpr}, 1\n")
100        }
101    }
102    pub(crate) fn update_adapter_heights_asm(
103        config: &SystemConfig,
104        _address_space: u32,
105    ) -> Result<String, AotError> {
106        let min_block_size_bits = config.memory_config.min_block_size_bits();
107        if min_block_size_bits[RV32_REGISTER_AS as usize] != DEFAULT_U8_BLOCK_SIZE_BITS {
108            println!("RV32_REGISTER_AS must have a minimum block size of 4");
109            return Err(AotError::Other(String::from(
110                "RV32_REGISTER_AS must have a minimum block size of 4",
111            )));
112        }
113        if min_block_size_bits[RV32_MEMORY_AS as usize] != DEFAULT_U8_BLOCK_SIZE_BITS {
114            println!("RV32_MEMORY_AS must have a minimum block size of 4");
115            return Err(AotError::Other(String::from(
116                "RV32_MEMORY_AS must have a minimum block size of 4",
117            )));
118        }
119        if min_block_size_bits[PUBLIC_VALUES_AS as usize] != DEFAULT_U8_BLOCK_SIZE_BITS {
120            println!("PUBLIC_VALUES_AS must have a minimum block size of 4");
121            return Err(AotError::Other(String::from(
122                "PUBLIC_VALUES_AS must have a minimum block size of 4",
123            )));
124        }
125
126        // `update_adapter_heights_asm` rewrites the following code in ASM for
127        // `on_memory_operation`: ```
128        // pub fn update_adapter_heights_batch(
129        //     &self,
130        //     trace_heights: &mut [u32],
131        //     address_space: u32,
132        //     size_bits: u32,
133        //     num: u32,
134        // ) {
135        //     let align_bits = unsafe {
136        //         *self
137        //             .min_block_size_bits
138        //             .get_unchecked(address_space as usize)
139        //     };
140        //
141        //     for adapter_bits in (align_bits as u32 + 1..=size_bits).rev() {
142        //         let adapter_idx = self.adapter_offset + adapter_bits as usize - 1;
143        //         debug_assert!(adapter_idx < trace_heights.len());
144        //         unsafe {
145        //             *trace_heights.get_unchecked_mut(adapter_idx) +=
146        //                 num << (size_bits - adapter_bits + 1);
147        //         }
148        //     }
149        // }
150        // ```
151        // 
152        // For a specific RV32 instruction, the variables can be treated as constants at AOT
153        // compilation time:
154        // - `address_space`: always a constant because it is derived from an Instruction
155        // - `num`: always 1 in `on_memory_operation`
156        // - `align_bits`: always a constant because `address_space` is a constant
157        // - `size_bits`: RV32 instruction always read 4 bytes(in the AIR level). So `size` is
158        //   always 4 bytes. So `size_bits` is always 2.
159        //
160        // If we ignore the Native address space, `min_block_size_bits`` is always
161        // `DEFAULT_U8_BLOCK_SIZE=4`. Therefore, `align_bits` is always 2. So the loop will
162        // never be executed and we can leave the function empty.
163        Ok("".to_string())
164    }
165
166    /// Generate ASM code for updating the boundary merkle heights.
167    ///
168    /// # Arguments
169    ///
170    /// * `config` - The system configuration.
171    /// * `address_space` - The address space.
172    /// * `pc` - The program counter of the current instruction.
173    /// * `ptr_reg` - The register to store the accessed pointer. The caller should not expect the
174    ///   value of this register to be preserved.
175    /// * `reg1` - A register to store the intermediate result.
176    /// * `reg2` - A register to store the intermediate result.
177    ///
178    /// # Returns
179    ///
180    /// The ASM code for updating the boundary merkle heights.
181    pub(crate) fn update_boundary_merkle_heights_asm<F>(
182        config: &SystemConfig,
183        address_space: u32,
184        pc: u32,
185        ptr_reg: &str,
186        reg1: &str,
187        reg2: &str,
188    ) -> Result<String, AotError> {
189        // `update_boundary_merkle_heights_asm` rewrites the following code in ASM for
190        // `on_memory_operation`: ```
191        // pub fn label_to_index((addr_space, block_id): (u32, u32)) -> u64 {
192        //     (((addr_space - ADDR_SPACE_OFFSET) as u64) << self.address_height) + block_id as u64
193        // }
194        //
195        // pub(crate) fn update_boundary_merkle_heights(
196        //     &mut self,
197        //     address_space: u32,
198        //     ptr: u32,
199        //     size: u32,
200        // ) {
201        //     let num_blocks = (size + self.chunk - 1) >> self.chunk_bits;
202        //     let start_chunk_id = ptr >> self.chunk_bits;
203        //     let start_block_id = if self.chunk == 1 {
204        //         start_chunk_id
205        //     } else {
206        //         self.memory_dimensions
207        //             .label_to_index((address_space, start_chunk_id)) as u32
208        //     };
209        //     // Because `self.chunk == 1 << self.chunk_bits`
210        //     let end_block_id = start_block_id + num_blocks;
211        //     let start_page_id = start_block_id >> PAGE_BITS;
212        //     let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1;
213
214        //     for page_id in start_page_id..end_page_id {
215        //          // Append page_id to page_indices_since_checkpoint
216        //          let len = self.page_indices_since_checkpoint_len;
217        //          // SAFETY: len is within bounds, and we extend length by 1 after writing.
218        //          unsafe {
219        //              *self.page_indices_since_checkpoint.as_mut_ptr().add(len) = page_id;
220        //          }
221        //          self.page_indices_since_checkpoint_len = len + 1;
222        //
223        //         if self.page_indices.insert(page_id as usize) {
224        //             // SAFETY: address_space passed is usually a hardcoded constant or derived
225        // from an             // Instruction where it is bounds checked before passing
226        //             unsafe {
227        //                 *self
228        //                     .addr_space_access_count
229        //                     .get_unchecked_mut(address_space as usize) += 1;
230        //             }
231        //         }
232        //     }
233        // }
234        // ```
235        // 
236        // For a specific RV32 instruction, the variables can be treated as constants at AOT compilation time:
237        // Inputs:
238        // - `chunk`: always 8(CHUNK) because we only support when continuation is enabled.
239        // - `address_space`: always a constant because it is derived from an Instruction
240        // - `size`: RV32 instruction always read 4 bytes(in the AIR level).
241        // - `self.memory_dimensions.address_height`: known at AOT compilation time because it is derived from the memory configuration.
242        // Inside the function body:
243        // - `num_blocks`: `(size + self.chunk - 1) >> self.chunk_bits = (4 + 8 - 1) >> 3 = 1`
244        // - `as_offset = (addr_space - ADDR_SPACE_OFFSET) as u64) << self.address_height)`: constant because `address_space` and `address_height` constant
245        // - `start_chunk_id`: `ptr >> self.chunk_bits`
246        // - `start_block_id`: `start_chunk_id + as_offset`
247        // - `end_block_id`: `start_block_id + num_blocks = start_block_id +1`
248        // - `start_page_id`: `start_block_id >> PAGE_BITS`
249        // - `end_page_id`: ((end_block_id - 1) >> PAGE_BITS) + 1 = start_block_id >> PAGE_BITS + 1;
250        //
251        // Therefore the loop only iterates once for `page_id = start_page_id`.
252
253        let initial_block_size: usize = config.initial_block_size();
254        if initial_block_size != CHUNK {
255            return Err(AotError::Other(format!(
256                "initial_block_size must be {CHUNK}, got {initial_block_size}"
257            )));
258        }
259        let chunk_bits = CHUNK.ilog2();
260        let as_offset = ((address_space - ADDR_SPACE_OFFSET) as u64)
261            << (config.memory_config.memory_dimensions().address_height);
262
263        let mut asm_str = String::new();
264        // `start_chunk_id`: `ptr >> self.chunk_bits`
265        asm_str += &format!("    shr {ptr_reg}, {chunk_bits}\n");
266        // `start_block_id`: `start_chunk_id + as_offset`
267        asm_str += &format!("    add {ptr_reg}, {as_offset}\n");
268        // `start_page_id`: `start_block_id >> PAGE_BITS`
269        // NOTE: This is DIRTY because PAGE_BITS is a generic parameter of E2 context.
270        asm_str += &format!("    shr {ptr_reg}, {DEFAULT_PAGE_BITS}\n");
271
272        let memory_ctx_offset = offset_of!(VmExecState<F, GuestMemory, MeteredCtx>, ctx)
273            + offset_of!(MeteredCtx, memory_ctx);
274        let page_indices_ptr_offset =
275            memory_ctx_offset + offset_of!(MemoryCtx<DEFAULT_PAGE_BITS>, page_indices);
276        let addr_space_access_count_ptr_offset =
277            memory_ctx_offset + offset_of!(MemoryCtx<DEFAULT_PAGE_BITS>, addr_space_access_count);
278        let page_indices_since_checkpoint_ptr_offset = memory_ctx_offset
279            + offset_of!(MemoryCtx<DEFAULT_PAGE_BITS>, page_indices_since_checkpoint);
280        let page_indices_since_checkpoint_len_offset = memory_ctx_offset
281            + offset_of!(
282                MemoryCtx<DEFAULT_PAGE_BITS>,
283                page_indices_since_checkpoint_len
284            );
285        let inserted_label = format!(".asm_execute_pc_{pc}_inserted");
286
287        // Append page_id to page_indices_since_checkpoint
288        asm_str += &format!(
289            "    mov {reg1}, [{REG_EXEC_STATE_PTR} + {page_indices_since_checkpoint_len_offset}]\n"
290        );
291        asm_str += &format!(
292            "    mov {reg2}, [{REG_EXEC_STATE_PTR} + {page_indices_since_checkpoint_ptr_offset}]\n"
293        );
294        let ptr_reg_32 = convert_x86_reg(ptr_reg, Width::W32).ok_or_else(|| {
295            AotError::Other(format!("unsupported ptr_reg for 32-bit store: {ptr_reg}"))
296        })?;
297        asm_str += &format!("    mov dword ptr [{reg2} + {reg1} * 4], {ptr_reg_32}\n");
298        asm_str += &format!("    add {reg1}, 1\n");
299        asm_str += &format!(
300            "    mov [{REG_EXEC_STATE_PTR} + {page_indices_since_checkpoint_len_offset}], {reg1}\n"
301        );
302
303        // The next section is the implementation of `BitSet::insert` in ASM.
304        // pub fn insert(&mut self, index: usize) -> bool {
305        //     let word_index = index >> 6;
306        //     let bit_index = index & 63;
307        //     let mask = 1u64 << bit_index;
308        //     let word = unsafe { self.words.get_unchecked_mut(word_index) };
309        //     let was_set = (*word & mask) != 0;
310        //     *word |= mask;
311        //     !was_set
312        // }
313
314        // Start with `ptr_reg = index`
315        // `reg1 = word_index`
316        asm_str += &format!("    mov {reg1}, {ptr_reg}\n");
317        asm_str += &format!("    shr {reg1}, 6\n");
318        // `ptr_reg = bit_index = index & 63`
319        asm_str += &format!("    and {ptr_reg}, 63\n");
320        // `reg2 = mask = 1u64 << bit_index`
321        asm_str += &format!("    mov {reg2}, 1\n");
322        asm_str += &format!("    shlx {reg2}, {reg2}, {ptr_reg}\n");
323        // `ptr_reg = self.page_indices.ptr`
324        asm_str +=
325            &format!("    mov {ptr_reg}, [{REG_EXEC_STATE_PTR} + {page_indices_ptr_offset}]\n");
326
327        // `reg1 = word_ptr = &self.words.get_unchecked_mut(word_index)`
328        asm_str += &format!("    lea {reg1}, [{ptr_reg} + {reg1} * 8]\n");
329        // `ptr_reg = word = *word_ptr`
330        asm_str += &format!("    mov {ptr_reg}, [{reg1}]\n");
331
332        // `test (*word & mask)`
333        asm_str += &format!("    test {ptr_reg}, {reg2}\n");
334        asm_str += &format!("    jnz {inserted_label}\n");
335        // When (*word & mask) == 0
336        // `*word += mask`
337        asm_str += &format!("    add {ptr_reg}, {reg2}\n");
338        asm_str += &format!("    mov [{reg1}], {ptr_reg}\n");
339        // reg1 = &addr_space_access_count.as_ptr()
340        asm_str += &format!(
341            "    lea {reg1}, [{REG_EXEC_STATE_PTR} + {addr_space_access_count_ptr_offset}]\n"
342        );
343        asm_str += &format!("    mov {reg1}, [{reg1}]\n");
344        // self.addr_space_access_count[address_space] += 1;
345        asm_str += &format!("    add dword ptr [{reg1} + {address_space} * 4], 1\n");
346        asm_str += &format!("{inserted_label}:\n");
347        // Inserted, do nothing
348
349        Ok(asm_str)
350    }
351
352    /// Assumption: `REG_TRACE_HEIGHT` is the pointer of `trace_heights``.
353    pub(crate) fn update_height_change_asm(
354        chip_idx: usize,
355        height_delta: u32,
356    ) -> Result<String, AotError> {
357        let mut asm_str = String::new();
358        // `update_height_change_asm` rewrites the following code in ASM for `on_height_change`:
359        // ```
360        // pub fn on_height_change(&mut self, chip_idx: usize, height_delta: u32) {
361        //     self.trace_heights[chip_idx] += height_delta;
362        // }
363        // ```
364        asm_str +=
365            &format!("    add dword ptr [{REG_TRACE_HEIGHT} + {chip_idx} * 4], {height_delta}\n");
366        Ok(asm_str)
367    }
368}