// openvm_rv32im_circuit/auipc/core.rs

1use std::{
2    array::{self, from_fn},
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7    arch::*,
8    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12    AlignedBytesBorrow,
13};
14use openvm_circuit_primitives_derive::AlignedBorrow;
15use openvm_instructions::{
16    instruction::Instruction,
17    program::{DEFAULT_PC_STEP, PC_BITS},
18    LocalOpcode,
19};
20use openvm_rv32im_transpiler::Rv32AuipcOpcode::{self, *};
21use openvm_stark_backend::{
22    interaction::InteractionBuilder,
23    p3_air::{AirBuilder, BaseAir},
24    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
25    rap::BaseAirWithPublicValues,
26};
27
28use crate::adapters::{
29    Rv32RdWriteAdapterExecutor, Rv32RdWriteAdapterFiller, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
30};
31
32#[repr(C)]
33#[derive(Debug, Clone, AlignedBorrow)]
34pub struct Rv32AuipcCoreCols<T> {
35    pub is_valid: T,
36    // The limbs of the immediate except the least significant limb since it is always 0
37    pub imm_limbs: [T; RV32_REGISTER_NUM_LIMBS - 1],
38    // The limbs of the PC except the most significant and the least significant limbs
39    pub pc_limbs: [T; RV32_REGISTER_NUM_LIMBS - 2],
40    pub rd_data: [T; RV32_REGISTER_NUM_LIMBS],
41}
42
43#[derive(Debug, Clone, Copy, derive_new::new)]
44pub struct Rv32AuipcCoreAir {
45    pub bus: BitwiseOperationLookupBus,
46}
47
48impl<F: Field> BaseAir<F> for Rv32AuipcCoreAir {
49    fn width(&self) -> usize {
50        Rv32AuipcCoreCols::<F>::width()
51    }
52}
53
54impl<F: Field> BaseAirWithPublicValues<F> for Rv32AuipcCoreAir {}
55
56impl<AB, I> VmCoreAir<AB, I> for Rv32AuipcCoreAir
57where
58    AB: InteractionBuilder,
59    I: VmAdapterInterface<AB::Expr>,
60    I::Reads: From<[[AB::Expr; 0]; 0]>,
61    I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
62    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
63{
64    fn eval(
65        &self,
66        builder: &mut AB,
67        local_core: &[AB::Var],
68        from_pc: AB::Var,
69    ) -> AdapterAirContext<AB::Expr, I> {
70        let cols: &Rv32AuipcCoreCols<AB::Var> = (*local_core).borrow();
71
72        let Rv32AuipcCoreCols {
73            is_valid,
74            imm_limbs,
75            pc_limbs,
76            rd_data,
77        } = *cols;
78        builder.assert_bool(is_valid);
79
80        // We want to constrain rd = pc + imm (i32 add) where:
81        // - rd_data represents limbs of rd
82        // - pc_limbs are limbs of pc except the most and least significant limbs
83        // - imm_limbs are limbs of imm except the least significant limb
84
85        // We know that rd_data[0] is equal to the least significant limb of PC
86        // Thus, the intermediate value will be equal to PC without its most significant limb:
87        let intermed_val = rd_data[0]
88            + pc_limbs
89                .iter()
90                .enumerate()
91                .fold(AB::Expr::ZERO, |acc, (i, &val)| {
92                    acc + val * AB::Expr::from_u32(1 << ((i + 1) * RV32_CELL_BITS))
93                });
94
95        // Compute the most significant limb of PC
96        let pc_msl = (from_pc - intermed_val)
97            * AB::F::from_usize(1 << (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))).inverse();
98
99        // The vector pc_limbs contains the actual limbs of PC in little endian order
100        let pc_limbs = [rd_data[0]]
101            .iter()
102            .chain(pc_limbs.iter())
103            .map(|x| (*x).into())
104            .chain([pc_msl])
105            .collect::<Vec<AB::Expr>>();
106
107        let mut carry: [AB::Expr; RV32_REGISTER_NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
108        let carry_divide = AB::F::from_usize(1 << RV32_CELL_BITS).inverse();
109
110        // Don't need to constrain the least significant limb of the addition
111        // since we already know that rd_data[0] = pc_limbs[0] and the least significant limb of imm
112        // is 0 Note: imm_limbs doesn't include the least significant limb so imm_limbs[i -
113        // 1] means the i-th limb of imm
114        for i in 1..RV32_REGISTER_NUM_LIMBS {
115            carry[i] = AB::Expr::from(carry_divide)
116                * (pc_limbs[i].clone() + imm_limbs[i - 1] - rd_data[i] + carry[i - 1].clone());
117            builder.when(is_valid).assert_bool(carry[i].clone());
118        }
119
120        // Range checking of rd_data entries to RV32_CELL_BITS bits
121        for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) {
122            self.bus
123                .send_range(rd_data[i * 2], rd_data[i * 2 + 1])
124                .eval(builder, is_valid);
125        }
126
127        // The immediate and PC limbs need range checking to ensure they're within [0,
128        // 2^RV32_CELL_BITS) Since we range check two items at a time, doing this way helps
129        // efficiently divide the limbs into groups of 2 Note: range checking the limbs of
130        // immediate and PC separately would result in additional range checks       since
131        // they both have odd number of limbs that need to be range checked
132        let mut need_range_check: Vec<AB::Expr> = Vec::new();
133        for limb in imm_limbs {
134            need_range_check.push(limb.into());
135        }
136
137        assert_eq!(pc_limbs.len(), RV32_REGISTER_NUM_LIMBS);
138        // use enumerate to match pc_limbs[0] => i = 0, pc_limbs[1] => i = 1, ...
139        // pc_limbs[0] is already range checked through rd_data[0], so we skip it
140        for (i, limb) in pc_limbs.iter().enumerate().skip(1) {
141            // the most significant limb is pc_limbs[3] => i = 3
142            if i == pc_limbs.len() - 1 {
143                // Range check the most significant limb of pc to be in [0,
144                // 2^{PC_BITS-(RV32_REGISTER_NUM_LIMBS-1)*RV32_CELL_BITS})
145                need_range_check.push(
146                    (*limb).clone()
147                        * AB::Expr::from_usize(1 << (pc_limbs.len() * RV32_CELL_BITS - PC_BITS)),
148                );
149            } else {
150                need_range_check.push((*limb).clone());
151            }
152        }
153
154        // need_range_check contains (RV32_REGISTER_NUM_LIMBS - 1) elements from imm_limbs
155        // and (RV32_REGISTER_NUM_LIMBS - 1) elements from pc_limbs
156        // Hence, is of even length 2*RV32_REGISTER_NUM_LIMBS - 2
157        assert_eq!(need_range_check.len() % 2, 0);
158        for pair in need_range_check.chunks_exact(2) {
159            self.bus
160                .send_range(pair[0].clone(), pair[1].clone())
161                .eval(builder, is_valid);
162        }
163
164        let imm = imm_limbs
165            .iter()
166            .enumerate()
167            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
168                acc + val * AB::Expr::from_u32(1 << (i * RV32_CELL_BITS))
169            });
170        let expected_opcode = VmCoreAir::<AB, I>::opcode_to_global_expr(self, AUIPC);
171        AdapterAirContext {
172            to_pc: None,
173            reads: [].into(),
174            writes: [rd_data.map(|x| x.into())].into(),
175            instruction: ImmInstruction {
176                is_valid: is_valid.into(),
177                opcode: expected_opcode,
178                immediate: imm,
179            }
180            .into(),
181        }
182    }
183
184    fn start_offset(&self) -> usize {
185        Rv32AuipcOpcode::CLASS_OFFSET
186    }
187}
188
189#[repr(C)]
190#[derive(AlignedBytesBorrow, Debug, Clone)]
191pub struct Rv32AuipcCoreRecord {
192    pub from_pc: u32,
193    pub imm: u32,
194}
195
196#[derive(Clone, Copy, derive_new::new)]
197pub struct Rv32AuipcExecutor<A = Rv32RdWriteAdapterExecutor> {
198    adapter: A,
199}
200
201#[derive(Clone, derive_new::new)]
202pub struct Rv32AuipcFiller<A = Rv32RdWriteAdapterFiller> {
203    adapter: A,
204    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
205}
206
207impl<F, A, RA> PreflightExecutor<F, RA> for Rv32AuipcExecutor<A>
208where
209    F: PrimeField32,
210    A: 'static + AdapterTraceExecutor<F, ReadData = (), WriteData = [u8; RV32_REGISTER_NUM_LIMBS]>,
211    for<'buf> RA: RecordArena<
212        'buf,
213        EmptyAdapterCoreLayout<F, A>,
214        (A::RecordMut<'buf>, &'buf mut Rv32AuipcCoreRecord),
215    >,
216{
217    fn get_opcode_name(&self, _: usize) -> String {
218        format!("{AUIPC:?}")
219    }
220
221    fn execute(
222        &self,
223        state: VmStateMut<F, TracingMemory, RA>,
224        instruction: &Instruction<F>,
225    ) -> Result<(), ExecutionError> {
226        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
227
228        A::start(*state.pc, state.memory, &mut adapter_record);
229
230        core_record.from_pc = *state.pc;
231        core_record.imm = instruction.c.as_canonical_u32();
232
233        let rd = run_auipc(*state.pc, core_record.imm);
234
235        self.adapter
236            .write(state.memory, instruction, rd, &mut adapter_record);
237
238        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
239
240        Ok(())
241    }
242}
243
244impl<F, A> TraceFiller<F> for Rv32AuipcFiller<A>
245where
246    F: PrimeField32,
247    A: 'static + AdapterTraceFiller<F>,
248{
249    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
250        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
251        // Rv32AuipcCoreCols::width() elements
252        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
253        self.adapter.fill_trace_row(mem_helper, adapter_row);
254        // SAFETY: core_row contains a valid Rv32AuipcCoreRecord written by the executor
255        // during trace generation
256        let record: &Rv32AuipcCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) };
257
258        let core_row: &mut Rv32AuipcCoreCols<F> = core_row.borrow_mut();
259
260        let imm_limbs = record.imm.to_le_bytes();
261        let pc_limbs = record.from_pc.to_le_bytes();
262        let rd_data = run_auipc(record.from_pc, record.imm);
263        debug_assert_eq!(imm_limbs[3], 0);
264
265        // range checks:
266        // hardcoding for performance: first 3 limbs of imm_limbs, last 3 limbs of pc_limbs where
267        // most significant limb of pc_limbs is shifted up
268        self.bitwise_lookup_chip
269            .request_range(imm_limbs[0] as u32, imm_limbs[1] as u32);
270        self.bitwise_lookup_chip
271            .request_range(imm_limbs[2] as u32, pc_limbs[1] as u32);
272        let msl_shift = RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - PC_BITS;
273        self.bitwise_lookup_chip
274            .request_range(pc_limbs[2] as u32, (pc_limbs[3] as u32) << msl_shift);
275        for pair in rd_data.chunks_exact(2) {
276            self.bitwise_lookup_chip
277                .request_range(pair[0] as u32, pair[1] as u32);
278        }
279        // Writing in reverse order
280        core_row.rd_data = rd_data.map(F::from_u8);
281        // only the middle 2 limbs:
282        core_row.pc_limbs = from_fn(|i| F::from_u8(pc_limbs[i + 1]));
283        core_row.imm_limbs = from_fn(|i| F::from_u8(imm_limbs[i]));
284
285        core_row.is_valid = F::ONE;
286    }
287}
288
289// returns rd_data
290#[inline(always)]
291pub(super) fn run_auipc(pc: u32, imm: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] {
292    let rd = pc.wrapping_add(imm << RV32_CELL_BITS);
293    rd.to_le_bytes()
294}