openvm_rv32im_circuit/auipc/
core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7    AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface,
8    VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::bitwise_op_lookup::{
11    BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
12};
13use openvm_circuit_primitives_derive::AlignedBorrow;
14use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode};
15use openvm_rv32im_transpiler::Rv32AuipcOpcode::{self, *};
16use openvm_stark_backend::{
17    interaction::InteractionBuilder,
18    p3_air::{AirBuilder, BaseAir},
19    p3_field::{Field, FieldAlgebra, PrimeField32},
20    rap::BaseAirWithPublicValues,
21};
22use serde::{Deserialize, Serialize};
23
24use crate::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
25
26const RV32_LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1;
27
28#[repr(C)]
29#[derive(Debug, Clone, AlignedBorrow)]
30pub struct Rv32AuipcCoreCols<T> {
31    pub is_valid: T,
32    // The limbs of the immediate except the least significant limb since it is always 0
33    pub imm_limbs: [T; RV32_REGISTER_NUM_LIMBS - 1],
34    // The limbs of the PC except the most significant and the least significant limbs
35    pub pc_limbs: [T; RV32_REGISTER_NUM_LIMBS - 2],
36    pub rd_data: [T; RV32_REGISTER_NUM_LIMBS],
37}
38
39#[derive(Debug, Clone)]
40pub struct Rv32AuipcCoreAir {
41    pub bus: BitwiseOperationLookupBus,
42}
43
44impl<F: Field> BaseAir<F> for Rv32AuipcCoreAir {
45    fn width(&self) -> usize {
46        Rv32AuipcCoreCols::<F>::width()
47    }
48}
49
50impl<F: Field> BaseAirWithPublicValues<F> for Rv32AuipcCoreAir {}
51
52impl<AB, I> VmCoreAir<AB, I> for Rv32AuipcCoreAir
53where
54    AB: InteractionBuilder,
55    I: VmAdapterInterface<AB::Expr>,
56    I::Reads: From<[[AB::Expr; 0]; 0]>,
57    I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
58    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
59{
60    fn eval(
61        &self,
62        builder: &mut AB,
63        local_core: &[AB::Var],
64        from_pc: AB::Var,
65    ) -> AdapterAirContext<AB::Expr, I> {
66        let cols: &Rv32AuipcCoreCols<AB::Var> = (*local_core).borrow();
67
68        let Rv32AuipcCoreCols {
69            is_valid,
70            imm_limbs,
71            pc_limbs,
72            rd_data,
73        } = *cols;
74        builder.assert_bool(is_valid);
75
76        // We want to constrain rd = pc + imm (i32 add) where:
77        // - rd_data represents limbs of rd
78        // - pc_limbs are limbs of pc except the most and least significant limbs
79        // - imm_limbs are limbs of imm except the least significant limb
80
81        // We know that rd_data[0] is equal to the least significant limb of PC
82        // Thus, the intermediate value will be equal to PC without its most significant limb:
83        let intermed_val = rd_data[0]
84            + pc_limbs
85                .iter()
86                .enumerate()
87                .fold(AB::Expr::ZERO, |acc, (i, &val)| {
88                    acc + val * AB::Expr::from_canonical_u32(1 << ((i + 1) * RV32_CELL_BITS))
89                });
90
91        // Compute the most significant limb of PC
92        let pc_msl = (from_pc - intermed_val)
93            * AB::F::from_canonical_usize(1 << (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1)))
94                .inverse();
95
96        // The vector pc_limbs contains the actual limbs of PC in little endian order
97        let pc_limbs = [rd_data[0]]
98            .iter()
99            .chain(pc_limbs.iter())
100            .map(|x| (*x).into())
101            .chain([pc_msl])
102            .collect::<Vec<AB::Expr>>();
103
104        let mut carry: [AB::Expr; RV32_REGISTER_NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
105        let carry_divide = AB::F::from_canonical_usize(1 << RV32_CELL_BITS).inverse();
106
107        // Don't need to constrain the least significant limb of the addition
108        // since we already know that rd_data[0] = pc_limbs[0] and the least significant limb of imm is 0
109        // Note: imm_limbs doesn't include the least significant limb so imm_limbs[i - 1] means the i-th limb of imm
110        for i in 1..RV32_REGISTER_NUM_LIMBS {
111            carry[i] = AB::Expr::from(carry_divide)
112                * (pc_limbs[i].clone() + imm_limbs[i - 1] - rd_data[i] + carry[i - 1].clone());
113            builder.when(is_valid).assert_bool(carry[i].clone());
114        }
115
116        // Range checking of rd_data entries to RV32_CELL_BITS bits
117        for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) {
118            self.bus
119                .send_range(rd_data[i * 2], rd_data[i * 2 + 1])
120                .eval(builder, is_valid);
121        }
122
123        // The immediate and PC limbs need range checking to ensure they're within [0, 2^RV32_CELL_BITS)
124        // Since we range check two items at a time, doing this way helps efficiently divide the limbs into groups of 2
125        // Note: range checking the limbs of immediate and PC separately would result in additional range checks
126        //       since they both have odd number of limbs that need to be range checked
127        let mut need_range_check: Vec<AB::Expr> = Vec::new();
128        for limb in imm_limbs {
129            need_range_check.push(limb.into());
130        }
131
132        // pc_limbs[0] is already range checked through rd_data[0]
133        for (i, limb) in pc_limbs.iter().skip(1).enumerate() {
134            if i == pc_limbs.len() - 1 {
135                // Range check the most significant limb of pc to be in [0, 2^{PC_BITS-(RV32_REGISTER_NUM_LIMBS-1)*RV32_CELL_BITS})
136                need_range_check.push(
137                    (*limb).clone()
138                        * AB::Expr::from_canonical_usize(
139                            1 << (pc_limbs.len() * RV32_CELL_BITS - PC_BITS),
140                        ),
141                );
142            } else {
143                need_range_check.push((*limb).clone());
144            }
145        }
146
147        // need_range_check contains (RV32_REGISTER_NUM_LIMBS - 1) elements from imm_limbs
148        // and (RV32_REGISTER_NUM_LIMBS - 1) elements from pc_limbs
149        // Hence, is of even length 2*RV32_REGISTER_NUM_LIMBS - 2
150        assert_eq!(need_range_check.len() % 2, 0);
151        for pair in need_range_check.chunks_exact(2) {
152            self.bus
153                .send_range(pair[0].clone(), pair[1].clone())
154                .eval(builder, is_valid);
155        }
156
157        let imm = imm_limbs
158            .iter()
159            .enumerate()
160            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
161                acc + val * AB::Expr::from_canonical_u32(1 << (i * RV32_CELL_BITS))
162            });
163        let expected_opcode = VmCoreAir::<AB, I>::opcode_to_global_expr(self, AUIPC);
164        AdapterAirContext {
165            to_pc: None,
166            reads: [].into(),
167            writes: [rd_data.map(|x| x.into())].into(),
168            instruction: ImmInstruction {
169                is_valid: is_valid.into(),
170                opcode: expected_opcode,
171                immediate: imm,
172            }
173            .into(),
174        }
175    }
176
177    fn start_offset(&self) -> usize {
178        Rv32AuipcOpcode::CLASS_OFFSET
179    }
180}
181
182#[repr(C)]
183#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct Rv32AuipcCoreRecord<F> {
185    pub imm_limbs: [F; RV32_REGISTER_NUM_LIMBS - 1],
186    pub pc_limbs: [F; RV32_REGISTER_NUM_LIMBS - 2],
187    pub rd_data: [F; RV32_REGISTER_NUM_LIMBS],
188}
189
190pub struct Rv32AuipcCoreChip {
191    pub air: Rv32AuipcCoreAir,
192    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
193}
194
195impl Rv32AuipcCoreChip {
196    pub fn new(bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>) -> Self {
197        Self {
198            air: Rv32AuipcCoreAir {
199                bus: bitwise_lookup_chip.bus(),
200            },
201            bitwise_lookup_chip,
202        }
203    }
204}
205
206impl<F: PrimeField32, I: VmAdapterInterface<F>> VmCoreChip<F, I> for Rv32AuipcCoreChip
207where
208    I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>,
209{
210    type Record = Rv32AuipcCoreRecord<F>;
211    type Air = Rv32AuipcCoreAir;
212
213    #[allow(clippy::type_complexity)]
214    fn execute_instruction(
215        &self,
216        instruction: &Instruction<F>,
217        from_pc: u32,
218        _reads: I::Reads,
219    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
220        let local_opcode = Rv32AuipcOpcode::from_usize(
221            instruction
222                .opcode
223                .local_opcode_idx(Rv32AuipcOpcode::CLASS_OFFSET),
224        );
225        let imm = instruction.c.as_canonical_u32();
226        let rd_data = run_auipc(local_opcode, from_pc, imm);
227        let rd_data_field = rd_data.map(F::from_canonical_u32);
228
229        let output = AdapterRuntimeContext::without_pc([rd_data_field]);
230
231        let imm_limbs = array::from_fn(|i| (imm >> (i * RV32_CELL_BITS)) & RV32_LIMB_MAX);
232        let pc_limbs: [u32; RV32_REGISTER_NUM_LIMBS] =
233            array::from_fn(|i| (from_pc >> (i * RV32_CELL_BITS)) & RV32_LIMB_MAX);
234
235        for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) {
236            self.bitwise_lookup_chip
237                .request_range(rd_data[i * 2], rd_data[i * 2 + 1]);
238        }
239
240        let mut need_range_check: Vec<u32> = Vec::new();
241        for limb in imm_limbs {
242            need_range_check.push(limb);
243        }
244
245        for (i, limb) in pc_limbs.iter().skip(1).enumerate() {
246            if i == pc_limbs.len() - 1 {
247                need_range_check.push((*limb) << (pc_limbs.len() * RV32_CELL_BITS - PC_BITS));
248            } else {
249                need_range_check.push(*limb);
250            }
251        }
252
253        for pair in need_range_check.chunks(2) {
254            self.bitwise_lookup_chip.request_range(pair[0], pair[1]);
255        }
256
257        Ok((
258            output,
259            Self::Record {
260                imm_limbs: imm_limbs.map(F::from_canonical_u32),
261                pc_limbs: array::from_fn(|i| F::from_canonical_u32(pc_limbs[i + 1])),
262                rd_data: rd_data.map(F::from_canonical_u32),
263            },
264        ))
265    }
266
267    fn get_opcode_name(&self, opcode: usize) -> String {
268        format!(
269            "{:?}",
270            Rv32AuipcOpcode::from_usize(opcode - Rv32AuipcOpcode::CLASS_OFFSET)
271        )
272    }
273
274    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
275        let core_cols: &mut Rv32AuipcCoreCols<F> = row_slice.borrow_mut();
276        core_cols.imm_limbs = record.imm_limbs;
277        core_cols.pc_limbs = record.pc_limbs;
278        core_cols.rd_data = record.rd_data;
279        core_cols.is_valid = F::ONE;
280    }
281
282    fn air(&self) -> &Self::Air {
283        &self.air
284    }
285}
286
287// returns rd_data
288pub(super) fn run_auipc(
289    _opcode: Rv32AuipcOpcode,
290    pc: u32,
291    imm: u32,
292) -> [u32; RV32_REGISTER_NUM_LIMBS] {
293    let rd = pc.wrapping_add(imm << RV32_CELL_BITS);
294    array::from_fn(|i| (rd >> (RV32_CELL_BITS * i)) & RV32_LIMB_MAX)
295}