// openvm_rv32im_circuit/adapters/alu.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4    arch::{
5        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
6        BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir,
7    },
8    system::memory::{
9        offline_checker::{
10            MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
11            MemoryWriteBytesAuxRecord,
12        },
13        online::TracingMemory,
14        MemoryAddress, MemoryAuxColsFactory,
15    },
16};
17use openvm_circuit_primitives::{
18    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
19    utils::not,
20    AlignedBytesBorrow,
21};
22use openvm_circuit_primitives_derive::AlignedBorrow;
23use openvm_instructions::{
24    instruction::Instruction,
25    program::DEFAULT_PC_STEP,
26    riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
27};
28use openvm_stark_backend::{
29    interaction::InteractionBuilder,
30    p3_air::{AirBuilder, BaseAir},
31    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
32};
33
34use super::{
35    tracing_read, tracing_read_imm, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
36};
37
/// Trace columns owned by the base ALU adapter, one set per executed instruction.
///
/// `#[repr(C)]` fixes the field order. The trace filler relies on this layout because it
/// converts an in-place record into these columns within the same buffer.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32BaseAluAdapterCols<T> {
    /// Program counter and timestamp at which this instruction starts.
    pub from_state: ExecutionState<T>,
    /// Register pointer of the destination `rd` (instruction operand `a`).
    pub rd_ptr: T,
    /// Register pointer of the first source `rs1` (instruction operand `b`).
    pub rs1_ptr: T,
    /// Pointer if rs2 was a read, immediate value otherwise
    pub rs2: T,
    /// 1 if rs2 was a read, 0 if an immediate
    pub rs2_as: T,
    /// Memory auxiliary columns for the rs1 read (index 0) and the rs2 read (index 1).
    /// Index 1 is zero-filled when rs2 is an immediate.
    pub reads_aux: [MemoryReadAuxCols<T>; 2],
    /// Memory auxiliary columns (including the previous data) for the rd write.
    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
51
/// Reads instructions of the form OP a, b, c, d, e where \[a:4\]_d = \[b:4\]_d op \[c:4\]_e.
/// Operand d can only be 1, and e can be either 1 (for register reads) or 0 (when c
/// is an immediate).
///
/// This adapter constrains:
/// - the register reads and the register write;
/// - the decoding of an immediate `c`;
/// - the pc/timestamp transition.
///
/// It does not constrain `op` itself. The operand and result values reach it through the
/// `AdapterAirContext` passed to `eval`.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32BaseAluAdapterAir {
    /// Bridge used to emit the execution-state (pc, timestamp) transition.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used for the rs1/rs2 read and rd write interactions.
    pub(super) memory_bridge: MemoryBridge,
    /// Lookup bus used to range-check the two low limbs of an immediate `rs2`.
    bitwise_lookup_bus: BitwiseOperationLookupBus,
}
61
62impl<F: Field> BaseAir<F> for Rv32BaseAluAdapterAir {
63    fn width(&self) -> usize {
64        Rv32BaseAluAdapterCols::<F>::width()
65    }
66}
67
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32BaseAluAdapterAir {
    // Two reads (`ctx.reads[0]` = rs1, `ctx.reads[1]` = rs2) and one write
    // (`ctx.writes[0]` = rd), each RV32_REGISTER_NUM_LIMBS limbs wide.
    type Interface = BasicAdapterInterface<
        AB::Expr,
        MinimalInstruction<AB::Expr>,
        2,
        1,
        RV32_REGISTER_NUM_LIMBS,
        RV32_REGISTER_NUM_LIMBS,
    >;

    /// Constrains the adapter half of a base ALU instruction: the rs1 register read, the rs2
    /// register read or immediate decoding, the rd register write, and the execution-state
    /// transition. The ALU operation itself is not constrained here; its inputs and output
    /// arrive through `ctx.reads` and `ctx.writes`.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local: &Rv32BaseAluAdapterCols<_> = local.borrow();
        let timestamp = local.from_state.timestamp;
        // `timestamp_pp()` yields `timestamp + timestamp_delta` and then increments the delta.
        // It is called three times below (rs1 read, rs2 read, rd write), so those accesses use
        // offsets 0, 1 and 2, and the final `timestamp_delta` is 3.
        let mut timestamp_delta: usize = 0;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_usize(timestamp_delta - 1)
        };

        // If rs2 is an immediate value, constrain that:
        // 1. It's a 16-bit two's complement integer (stored in rs2_limbs[0] and rs2_limbs[1])
        // 2. It's properly sign-extended to 32-bits (the upper limbs must match the sign bit)
        // Concretely, when rs2_as == 0:
        //   rs2 == limb0 + limb1 * 2^RV32_CELL_BITS + sign * 2^(2 * RV32_CELL_BITS)
        //   limb3 == limb2 (= sign)
        //   sign is 0 or 2^RV32_CELL_BITS - 1
        let rs2_limbs = ctx.reads[1].clone();
        let rs2_sign = rs2_limbs[2].clone();
        let rs2_imm = rs2_limbs[0].clone()
            + rs2_limbs[1].clone() * AB::Expr::from_usize(1 << RV32_CELL_BITS)
            + rs2_sign.clone() * AB::Expr::from_usize(1 << (2 * RV32_CELL_BITS));
        builder.assert_bool(local.rs2_as);
        let mut rs2_imm_when = builder.when(not(local.rs2_as));
        rs2_imm_when.assert_eq(local.rs2, rs2_imm);
        rs2_imm_when.assert_eq(rs2_sign.clone(), rs2_limbs[3].clone());
        // sign * ((2^RV32_CELL_BITS - 1) - sign) == 0, i.e. sign is all-zeros or all-ones.
        rs2_imm_when.assert_zero(
            rs2_sign.clone() * (AB::Expr::from_usize((1 << RV32_CELL_BITS) - 1) - rs2_sign),
        );
        // Range-check the two low immediate limbs. The multiplicity `is_valid - rs2_as` is 1
        // only on a valid row whose rs2 is an immediate. It is 0 otherwise, because a row with
        // `rs2_as == 1` is forced to have `is_valid == 1` below.
        self.bitwise_lookup_bus
            .send_range(rs2_limbs[0].clone(), rs2_limbs[1].clone())
            .eval(builder, ctx.instruction.is_valid.clone() - local.rs2_as);

        // rs1 is always read from the register address space, at timestamp offset 0.
        self.memory_bridge
            .read(
                MemoryAddress::new(AB::F::from_u32(RV32_REGISTER_AS), local.rs1_ptr),
                ctx.reads[0].clone(),
                timestamp_pp(),
                &local.reads_aux[0],
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // This constraint ensures that the following memory read only occurs when `is_valid == 1`.
        builder
            .when(local.rs2_as)
            .assert_one(ctx.instruction.is_valid.clone());
        // rs2 read at timestamp offset 1, gated by `rs2_as`. `rs2_as` itself is used as the
        // address space. `timestamp_pp()` is consumed even when rs2 is an immediate, so the
        // rd write below always lands at offset 2.
        self.memory_bridge
            .read(
                MemoryAddress::new(local.rs2_as, local.rs2),
                ctx.reads[1].clone(),
                timestamp_pp(),
                &local.reads_aux[1],
            )
            .eval(builder, local.rs2_as);

        // rd is written to the register address space, at timestamp offset 2.
        self.memory_bridge
            .write(
                MemoryAddress::new(AB::F::from_u32(RV32_REGISTER_AS), local.rd_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &local.writes_aux,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Instruction operands [a, b, c, d, e] = [rd_ptr, rs1_ptr, rs2, RV32_REGISTER_AS, rs2_as].
        // The timestamp advances by the accumulated `timestamp_delta` (3).
        // NOTE(review): the pc presumably advances by DEFAULT_PC_STEP unless `ctx.to_pc`
        // overrides it. Confirm against `execute_and_increment_or_set_pc`.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    local.rd_ptr.into(),
                    local.rs1_ptr.into(),
                    local.rs2.into(),
                    AB::Expr::from_u32(RV32_REGISTER_AS),
                    local.rs2_as.into(),
                ],
                local.from_state,
                AB::F::from_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the pc column of this row's starting execution state.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32BaseAluAdapterCols<_> = local.borrow();
        cols.from_state.pc
    }
}
164
/// Execution-time half of the base ALU adapter. It performs the rs1/rs2 reads (rs2 may be an
/// immediate) and the rd write against `TracingMemory`, recording what it needs into an
/// `Rv32BaseAluAdapterRecord`.
#[derive(Clone, derive_new::new)]
pub struct Rv32BaseAluAdapterExecutor<const LIMB_BITS: usize>;
167
/// Trace-generation half of the base ALU adapter. It expands an `Rv32BaseAluAdapterRecord`
/// into `Rv32BaseAluAdapterCols` in place.
#[derive(derive_new::new)]
pub struct Rv32BaseAluAdapterFiller<const LIMB_BITS: usize> {
    /// Receives range-check requests for the two low limbs of an immediate `rs2`.
    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
172
// Intermediate type that should not be copied or cloned and should be directly written to.
// It is filled by `Rv32BaseAluAdapterExecutor` during execution. Later,
// `Rv32BaseAluAdapterFiller` reinterprets it from the trace row buffer and overwrites it in
// place with `Rv32BaseAluAdapterCols`, so its `#[repr(C)]` layout must not change casually.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32BaseAluAdapterRecord {
    /// Program counter at the start of the instruction.
    pub from_pc: u32,
    /// Memory timestamp at the start of the instruction.
    pub from_timestamp: u32,

    /// Register pointer of `rd` (instruction operand `a`).
    pub rd_ptr: u32,
    /// Register pointer of `rs1` (instruction operand `b`).
    pub rs1_ptr: u32,
    /// Pointer if rs2 was a read, immediate value otherwise
    pub rs2: u32,
    /// 1 if rs2 was a read, 0 if an immediate
    pub rs2_as: u8,

    /// Previous-access timestamps for the rs1 read (index 0) and rs2 read (index 1).
    /// Index 1 is not written by the executor when rs2 is an immediate.
    pub reads_aux: [MemoryReadAuxRecord; 2],
    /// Previous timestamp and previous data of the rd register.
    pub writes_aux: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
}
190
191impl<F: PrimeField32, const LIMB_BITS: usize> AdapterTraceExecutor<F>
192    for Rv32BaseAluAdapterExecutor<LIMB_BITS>
193{
194    const WIDTH: usize = size_of::<Rv32BaseAluAdapterCols<u8>>();
195    type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2];
196    type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1];
197    type RecordMut<'a> = &'a mut Rv32BaseAluAdapterRecord;
198
199    #[inline(always)]
200    fn start(pc: u32, memory: &TracingMemory, record: &mut &mut Rv32BaseAluAdapterRecord) {
201        record.from_pc = pc;
202        record.from_timestamp = memory.timestamp;
203    }
204
205    // @dev cannot get rid of double &mut due to trait
206    #[inline(always)]
207    fn read(
208        &self,
209        memory: &mut TracingMemory,
210        instruction: &Instruction<F>,
211        record: &mut &mut Rv32BaseAluAdapterRecord,
212    ) -> Self::ReadData {
213        let &Instruction { b, c, d, e, .. } = instruction;
214
215        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
216        debug_assert!(
217            e.as_canonical_u32() == RV32_REGISTER_AS || e.as_canonical_u32() == RV32_IMM_AS
218        );
219
220        record.rs1_ptr = b.as_canonical_u32();
221        let rs1 = tracing_read(
222            memory,
223            RV32_REGISTER_AS,
224            record.rs1_ptr,
225            &mut record.reads_aux[0].prev_timestamp,
226        );
227
228        let rs2 = if e.as_canonical_u32() == RV32_REGISTER_AS {
229            record.rs2_as = RV32_REGISTER_AS as u8;
230            record.rs2 = c.as_canonical_u32();
231
232            tracing_read(
233                memory,
234                RV32_REGISTER_AS,
235                record.rs2,
236                &mut record.reads_aux[1].prev_timestamp,
237            )
238        } else {
239            record.rs2_as = RV32_IMM_AS as u8;
240
241            tracing_read_imm(memory, c.as_canonical_u32(), &mut record.rs2)
242        };
243
244        [rs1, rs2]
245    }
246
247    #[inline(always)]
248    fn write(
249        &self,
250        memory: &mut TracingMemory,
251        instruction: &Instruction<F>,
252        data: Self::WriteData,
253        record: &mut &mut Rv32BaseAluAdapterRecord,
254    ) {
255        let &Instruction { a, d, .. } = instruction;
256
257        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
258
259        record.rd_ptr = a.as_canonical_u32();
260        tracing_write(
261            memory,
262            RV32_REGISTER_AS,
263            record.rd_ptr,
264            data[0],
265            &mut record.writes_aux.prev_timestamp,
266            &mut record.writes_aux.prev_data,
267        );
268    }
269}
270
271impl<F: PrimeField32, const LIMB_BITS: usize> AdapterTraceFiller<F>
272    for Rv32BaseAluAdapterFiller<LIMB_BITS>
273{
274    const WIDTH: usize = size_of::<Rv32BaseAluAdapterCols<u8>>();
275
276    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
277        // SAFETY: the following is highly unsafe. We are going to cast `adapter_row` to a record
278        // buffer, and then do an _overlapping_ write to the `adapter_row` as a row of field
279        // elements. This requires:
280        // - Cols struct should be repr(C) and we write in reverse order (to ensure non-overlapping)
281        // - Do not overwrite any reference in `record` before it has already been used or moved
282        // - alignment of `F` must be >= alignment of Record (AlignedBytesBorrow will panic
283        //   otherwise)
284        // - adapter_row contains a valid Rv32BaseAluAdapterRecord representation
285        // - get_record_from_slice correctly interprets the bytes as Rv32BaseAluAdapterRecord
286        let record: &Rv32BaseAluAdapterRecord =
287            unsafe { get_record_from_slice(&mut adapter_row, ()) };
288        let adapter_row: &mut Rv32BaseAluAdapterCols<F> = adapter_row.borrow_mut();
289
290        // We must assign in reverse
291        const TIMESTAMP_DELTA: u32 = 2;
292        let mut timestamp = record.from_timestamp + TIMESTAMP_DELTA;
293
294        adapter_row
295            .writes_aux
296            .set_prev_data(record.writes_aux.prev_data.map(F::from_u8));
297        mem_helper.fill(
298            record.writes_aux.prev_timestamp,
299            timestamp,
300            adapter_row.writes_aux.as_mut(),
301        );
302        timestamp -= 1;
303
304        if record.rs2_as != 0 {
305            mem_helper.fill(
306                record.reads_aux[1].prev_timestamp,
307                timestamp,
308                adapter_row.reads_aux[1].as_mut(),
309            );
310        } else {
311            mem_helper.fill_zero(adapter_row.reads_aux[1].as_mut());
312            let rs2_imm = record.rs2;
313            let mask = (1 << RV32_CELL_BITS) - 1;
314            self.bitwise_lookup_chip
315                .request_range(rs2_imm & mask, (rs2_imm >> 8) & mask);
316        }
317        timestamp -= 1;
318
319        mem_helper.fill(
320            record.reads_aux[0].prev_timestamp,
321            timestamp,
322            adapter_row.reads_aux[0].as_mut(),
323        );
324
325        adapter_row.rs2_as = F::from_u8(record.rs2_as);
326        adapter_row.rs2 = F::from_u32(record.rs2);
327        adapter_row.rs1_ptr = F::from_u32(record.rs1_ptr);
328        adapter_row.rd_ptr = F::from_u32(record.rd_ptr);
329        adapter_row.from_state.timestamp = F::from_u32(timestamp);
330        adapter_row.from_state.pc = F::from_u32(record.from_pc);
331    }
332}