// openvm_rv32im_circuit/auipc/core.rs

1use std::{
2    array::{self, from_fn},
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7    arch::*,
8    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12    AlignedBytesBorrow,
13};
14use openvm_circuit_primitives_derive::AlignedBorrow;
15use openvm_instructions::{
16    instruction::Instruction,
17    program::{DEFAULT_PC_STEP, PC_BITS},
18    LocalOpcode,
19};
20use openvm_rv32im_transpiler::Rv32AuipcOpcode::{self, *};
21use openvm_stark_backend::{
22    interaction::InteractionBuilder,
23    p3_air::{AirBuilder, BaseAir},
24    p3_field::{Field, FieldAlgebra, PrimeField32},
25    rap::BaseAirWithPublicValues,
26};
27
28use crate::adapters::{
29    Rv32RdWriteAdapterExecutor, Rv32RdWriteAdapterFiller, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
30};
31
/// Columns of the AUIPC core AIR trace.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32AuipcCoreCols<T> {
    pub is_valid: T,
    /// Limbs of the immediate, excluding the least significant limb (which is always 0).
    pub imm_limbs: [T; RV32_REGISTER_NUM_LIMBS - 1],
    /// Limbs of the PC, excluding the most significant and the least significant limbs.
    pub pc_limbs: [T; RV32_REGISTER_NUM_LIMBS - 2],
    /// All limbs of the destination register value rd.
    pub rd_data: [T; RV32_REGISTER_NUM_LIMBS],
}
42
/// Core AIR for the RV32 AUIPC instruction.
#[derive(Debug, Clone, Copy, derive_new::new)]
pub struct Rv32AuipcCoreAir {
    /// Lookup bus used to range check limbs (two values per send).
    pub bus: BitwiseOperationLookupBus,
}
47
impl<F: Field> BaseAir<F> for Rv32AuipcCoreAir {
    /// Trace width is determined entirely by the column struct.
    fn width(&self) -> usize {
        Rv32AuipcCoreCols::<F>::width()
    }
}
53
54impl<F: Field> BaseAirWithPublicValues<F> for Rv32AuipcCoreAir {}
55
impl<AB, I> VmCoreAir<AB, I> for Rv32AuipcCoreAir
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; 0]; 0]>,
    I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
{
    /// Constrains `rd = from_pc + (imm << RV32_CELL_BITS)` as a limb-wise u32 addition
    /// and range checks all participating limbs over the bitwise-operation lookup bus.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &Rv32AuipcCoreCols<AB::Var> = (*local_core).borrow();

        let Rv32AuipcCoreCols {
            is_valid,
            imm_limbs,
            pc_limbs,
            rd_data,
        } = *cols;
        builder.assert_bool(is_valid);

        // We want to constrain rd = pc + imm (limb-wise u32 add) where:
        // - rd_data represents the limbs of rd
        // - pc_limbs are the limbs of pc except the most and least significant limbs
        // - imm_limbs are the limbs of imm except the least significant limb (always 0)

        // We know that rd_data[0] is equal to the least significant limb of PC
        // (the limb-0 addend from imm is 0). Thus, the intermediate value equals
        // PC without its most significant limb:
        let intermed_val = rd_data[0]
            + pc_limbs
                .iter()
                .enumerate()
                .fold(AB::Expr::ZERO, |acc, (i, &val)| {
                    acc + val * AB::Expr::from_canonical_u32(1 << ((i + 1) * RV32_CELL_BITS))
                });

        // Solve for the most significant limb of PC from from_pc and the lower limbs.
        let pc_msl = (from_pc - intermed_val)
            * AB::F::from_canonical_usize(1 << (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1)))
                .inverse();

        // The vector pc_limbs now contains all limbs of PC in little-endian order.
        let pc_limbs = [rd_data[0]]
            .iter()
            .chain(pc_limbs.iter())
            .map(|x| (*x).into())
            .chain([pc_msl])
            .collect::<Vec<AB::Expr>>();

        let mut carry: [AB::Expr; RV32_REGISTER_NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        let carry_divide = AB::F::from_canonical_usize(1 << RV32_CELL_BITS).inverse();

        // Constrain the limb-wise addition by requiring every carry to be boolean.
        // The least significant limb needs no constraint since rd_data[0] = pc_limbs[0]
        // and the least significant limb of imm is 0.
        // Note: imm_limbs doesn't include the least significant limb, so imm_limbs[i - 1]
        // is the i-th limb of imm.
        for i in 1..RV32_REGISTER_NUM_LIMBS {
            carry[i] = AB::Expr::from(carry_divide)
                * (pc_limbs[i].clone() + imm_limbs[i - 1] - rd_data[i] + carry[i - 1].clone());
            builder.when(is_valid).assert_bool(carry[i].clone());
        }

        // Range check rd_data entries to RV32_CELL_BITS bits, two limbs per lookup send.
        for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) {
            self.bus
                .send_range(rd_data[i * 2], rd_data[i * 2 + 1])
                .eval(builder, is_valid);
        }

        // The immediate and PC limbs also need range checking to [0, 2^RV32_CELL_BITS).
        // Since each lookup send checks two values at a time, we interleave both sets
        // into a single list and check consecutive pairs. Range checking the limbs of
        // the immediate and the PC separately would need an extra send, because each
        // set has an odd number of limbs to check.
        let mut need_range_check: Vec<AB::Expr> = Vec::new();
        for limb in imm_limbs {
            need_range_check.push(limb.into());
        }

        assert_eq!(pc_limbs.len(), RV32_REGISTER_NUM_LIMBS);
        // use enumerate to match pc_limbs[0] => i = 0, pc_limbs[1] => i = 1, ...
        // pc_limbs[0] is already range checked through rd_data[0], so we skip it
        for (i, limb) in pc_limbs.iter().enumerate().skip(1) {
            // the most significant limb is pc_limbs[3] => i = 3
            if i == pc_limbs.len() - 1 {
                // Range check the most significant limb of pc to be in [0,
                // 2^{PC_BITS-(RV32_REGISTER_NUM_LIMBS-1)*RV32_CELL_BITS}) by shifting it
                // up before the RV32_CELL_BITS-bit check.
                need_range_check.push(
                    (*limb).clone()
                        * AB::Expr::from_canonical_usize(
                            1 << (pc_limbs.len() * RV32_CELL_BITS - PC_BITS),
                        ),
                );
            } else {
                need_range_check.push((*limb).clone());
            }
        }

        // need_range_check contains (RV32_REGISTER_NUM_LIMBS - 1) elements from imm_limbs
        // and (RV32_REGISTER_NUM_LIMBS - 1) elements from pc_limbs.
        // Hence, it has even length 2*RV32_REGISTER_NUM_LIMBS - 2 and chunks_exact(2)
        // covers every element.
        assert_eq!(need_range_check.len() % 2, 0);
        for pair in need_range_check.chunks_exact(2) {
            self.bus
                .send_range(pair[0].clone(), pair[1].clone())
                .eval(builder, is_valid);
        }

        // Reconstruct the immediate value from its limbs for the processed instruction.
        let imm = imm_limbs
            .iter()
            .enumerate()
            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
                acc + val * AB::Expr::from_canonical_u32(1 << (i * RV32_CELL_BITS))
            });
        let expected_opcode = VmCoreAir::<AB, I>::opcode_to_global_expr(self, AUIPC);
        AdapterAirContext {
            // AUIPC never jumps; the adapter advances the PC by the default step.
            to_pc: None,
            reads: [].into(),
            writes: [rd_data.map(|x| x.into())].into(),
            instruction: ImmInstruction {
                is_valid: is_valid.into(),
                opcode: expected_opcode,
                immediate: imm,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        Rv32AuipcOpcode::CLASS_OFFSET
    }
}
191
/// Minimal per-row record captured at execution time; expanded into
/// `Rv32AuipcCoreCols` by the trace filler.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug, Clone)]
pub struct Rv32AuipcCoreRecord {
    /// Program counter of the executed AUIPC instruction.
    pub from_pc: u32,
    /// Immediate operand (instruction operand `c`, canonical u32).
    pub imm: u32,
}
198
/// Preflight executor for AUIPC, generic over the rd-write adapter.
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32AuipcExecutor<A = Rv32RdWriteAdapterExecutor> {
    adapter: A,
}
203
/// Trace filler for AUIPC rows, generic over the rd-write adapter filler.
#[derive(Clone, derive_new::new)]
pub struct Rv32AuipcFiller<A = Rv32RdWriteAdapterFiller> {
    adapter: A,
    /// Shared lookup chip used to issue the byte range checks the AIR expects.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
209
210impl<F, A, RA> PreflightExecutor<F, RA> for Rv32AuipcExecutor<A>
211where
212    F: PrimeField32,
213    A: 'static + AdapterTraceExecutor<F, ReadData = (), WriteData = [u8; RV32_REGISTER_NUM_LIMBS]>,
214    for<'buf> RA: RecordArena<
215        'buf,
216        EmptyAdapterCoreLayout<F, A>,
217        (A::RecordMut<'buf>, &'buf mut Rv32AuipcCoreRecord),
218    >,
219{
220    fn get_opcode_name(&self, _: usize) -> String {
221        format!("{:?}", AUIPC)
222    }
223
224    fn execute(
225        &self,
226        state: VmStateMut<F, TracingMemory, RA>,
227        instruction: &Instruction<F>,
228    ) -> Result<(), ExecutionError> {
229        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
230
231        A::start(*state.pc, state.memory, &mut adapter_record);
232
233        core_record.from_pc = *state.pc;
234        core_record.imm = instruction.c.as_canonical_u32();
235
236        let rd = run_auipc(*state.pc, core_record.imm);
237
238        self.adapter
239            .write(state.memory, instruction, rd, &mut adapter_record);
240
241        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
242
243        Ok(())
244    }
245}
246
impl<F, A> TraceFiller<F> for Rv32AuipcFiller<A>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    /// Expands the packed `Rv32AuipcCoreRecord` stored at the start of `row_slice`
    /// into the full `Rv32AuipcCoreCols` layout, issuing the same range-check
    /// lookups that the AIR sends.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
        // Rv32AuipcCoreCols::width() elements
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: core_row contains a valid Rv32AuipcCoreRecord written by the executor
        // during trace generation
        let record: &Rv32AuipcCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) };

        let core_row: &mut Rv32AuipcCoreCols<F> = core_row.borrow_mut();

        // Copy everything out of the record before the writes below overwrite the row.
        let imm_limbs = record.imm.to_le_bytes();
        let pc_limbs = record.from_pc.to_le_bytes();
        let rd_data = run_auipc(record.from_pc, record.imm);
        // imm fits in RV32_REGISTER_NUM_LIMBS - 1 limbs, so its top byte must be 0.
        debug_assert_eq!(imm_limbs[3], 0);

        // Range checks, mirroring the AIR's lookup sends.
        // Hardcoded pairing for performance: the first 3 limbs of imm_limbs and the last
        // 3 limbs of pc_limbs, where the most significant limb of pc_limbs is shifted up
        // (the AIR checks the shifted value so that limb is effectively bounded by
        // PC_BITS).
        self.bitwise_lookup_chip
            .request_range(imm_limbs[0] as u32, imm_limbs[1] as u32);
        self.bitwise_lookup_chip
            .request_range(imm_limbs[2] as u32, pc_limbs[1] as u32);
        let msl_shift = RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - PC_BITS;
        self.bitwise_lookup_chip
            .request_range(pc_limbs[2] as u32, (pc_limbs[3] as u32) << msl_shift);
        for pair in rd_data.chunks_exact(2) {
            self.bitwise_lookup_chip
                .request_range(pair[0] as u32, pair[1] as u32);
        }
        // Write fields in reverse struct order: the packed record occupies the prefix of
        // this same buffer, so filling from the back avoids clobbering it prematurely.
        core_row.rd_data = rd_data.map(F::from_canonical_u8);
        // pc_limbs columns hold only the middle 2 limbs of the PC:
        core_row.pc_limbs = from_fn(|i| F::from_canonical_u8(pc_limbs[i + 1]));
        // imm_limbs columns hold the low 3 limbs of imm (the 4th is always 0):
        core_row.imm_limbs = from_fn(|i| F::from_canonical_u8(imm_limbs[i]));

        core_row.is_valid = F::ONE;
    }
}
291
292// returns rd_data
293#[inline(always)]
294pub(super) fn run_auipc(pc: u32, imm: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] {
295    let rd = pc.wrapping_add(imm << RV32_CELL_BITS);
296    rd.to_le_bytes()
297}