openvm_rv32im_circuit/adapters/
alu.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4    arch::{
5        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
6        BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir,
7    },
8    system::memory::{
9        offline_checker::{
10            MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
11            MemoryWriteBytesAuxRecord,
12        },
13        online::TracingMemory,
14        MemoryAddress, MemoryAuxColsFactory,
15    },
16};
17use openvm_circuit_primitives::{
18    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
19    utils::not,
20    AlignedBytesBorrow,
21};
22use openvm_circuit_primitives_derive::AlignedBorrow;
23use openvm_instructions::{
24    instruction::Instruction,
25    program::DEFAULT_PC_STEP,
26    riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
27};
28use openvm_stark_backend::{
29    interaction::InteractionBuilder,
30    p3_air::{AirBuilder, BaseAir},
31    p3_field::{Field, FieldAlgebra, PrimeField32},
32};
33
34use super::{
35    tracing_read, tracing_read_imm, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
36};
37
/// Trace-row columns of the RV32 base ALU adapter.
///
/// The adapter AIR borrows each row slice as this struct, so the field order below is the
/// column order of the row (`#[repr(C)]`).
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32BaseAluAdapterCols<T> {
    /// Execution state (pc and timestamp) the instruction starts from.
    pub from_state: ExecutionState<T>,
    /// Pointer of the destination register; the result is written there in the register
    /// address space.
    pub rd_ptr: T,
    /// Pointer of the first source register; always read from the register address space.
    pub rs1_ptr: T,
    /// Pointer if rs2 was a read, immediate value otherwise
    pub rs2: T,
    /// 1 if rs2 was a read, 0 if an immediate
    pub rs2_as: T,
    /// Auxiliary columns for the two memory reads: index 0 is the rs1 read, index 1 is the rs2
    /// read (the latter is only enabled when `rs2_as` is 1).
    pub reads_aux: [MemoryReadAuxCols<T>; 2],
    /// Auxiliary columns for the write to `rd_ptr`.
    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
51
/// Reads instructions of the form OP a, b, c, d, e where \[a:4\]_d = \[b:4\]_d op \[c:4\]_e.
/// Operand d can only be 1, and e can be either 1 (for register reads) or 0 (when c
/// is an immediate).
///
/// Holds the bridges/buses that the adapter's `eval` sends interactions on.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32BaseAluAdapterAir {
    /// Used to emit the execution-bus interaction that binds the row to an instruction.
    pub(super) execution_bridge: ExecutionBridge,
    /// Used to emit the register reads and the register write.
    pub(super) memory_bridge: MemoryBridge,
    /// Used to send the two low limbs of an immediate rs2 via `send_range`.
    bitwise_lookup_bus: BitwiseOperationLookupBus,
}
61
62impl<F: Field> BaseAir<F> for Rv32BaseAluAdapterAir {
63    fn width(&self) -> usize {
64        Rv32BaseAluAdapterCols::<F>::width()
65    }
66}
67
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32BaseAluAdapterAir {
    // `eval` below consumes two read operands (`ctx.reads[0]`, `ctx.reads[1]`) and one write
    // operand (`ctx.writes[0]`) through this interface.
    type Interface = BasicAdapterInterface<
        AB::Expr,
        MinimalInstruction<AB::Expr>,
        2,
        1,
        RV32_REGISTER_NUM_LIMBS,
        RV32_REGISTER_NUM_LIMBS,
    >;

    /// Emits the adapter constraints and interactions for one row:
    /// the immediate-format constraints on rs2, the rs1/rs2 register reads, the rd write, and
    /// the execution-bus interaction that ties the row to the instruction operands.
    ///
    /// The order of the statements matters: `timestamp_pp` is stateful and hands out
    /// consecutive timestamps in call order.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local: &Rv32BaseAluAdapterCols<_> = local.borrow();
        let timestamp = local.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Post-increment helper: returns `timestamp + timestamp_delta`, then bumps the delta.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        // If rs2 is an immediate value, constrain that:
        // 1. It's a 16-bit two's complement integer (stored in rs2_limbs[0] and rs2_limbs[1])
        // 2. It's properly sign-extended to 32-bits (the upper limbs must match the sign bit)
        let rs2_limbs = ctx.reads[1].clone();
        let rs2_sign = rs2_limbs[2].clone();
        // Recomposition of the three low limbs: limb0 + limb1 * 2^CELL_BITS + limb2 * 2^(2*CELL_BITS).
        let rs2_imm = rs2_limbs[0].clone()
            + rs2_limbs[1].clone() * AB::Expr::from_canonical_usize(1 << RV32_CELL_BITS)
            + rs2_sign.clone() * AB::Expr::from_canonical_usize(1 << (2 * RV32_CELL_BITS));
        builder.assert_bool(local.rs2_as);
        // Everything asserted through `rs2_imm_when` only applies when `rs2_as == 0`.
        let mut rs2_imm_when = builder.when(not(local.rs2_as));
        rs2_imm_when.assert_eq(local.rs2, rs2_imm);
        // Limb 3 must equal limb 2 (the sign limb).
        rs2_imm_when.assert_eq(rs2_sign.clone(), rs2_limbs[3].clone());
        // The sign limb is either 0 or 2^CELL_BITS - 1.
        rs2_imm_when.assert_zero(
            rs2_sign.clone()
                * (AB::Expr::from_canonical_usize((1 << RV32_CELL_BITS) - 1) - rs2_sign),
        );
        // The two low limbs go to the bitwise lookup bus with multiplicity `is_valid - rs2_as`,
        // which is 1 exactly when `is_valid == 1` and `rs2_as == 0`.
        self.bitwise_lookup_bus
            .send_range(rs2_limbs[0].clone(), rs2_limbs[1].clone())
            .eval(builder, ctx.instruction.is_valid.clone() - local.rs2_as);

        // rs1 read from the register address space at the starting timestamp.
        self.memory_bridge
            .read(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rs1_ptr),
                ctx.reads[0].clone(),
                timestamp_pp(),
                &local.reads_aux[0],
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // This constraint ensures that the following memory read only occurs when `is_valid == 1`.
        builder
            .when(local.rs2_as)
            .assert_one(ctx.instruction.is_valid.clone());
        // rs2 read: `rs2_as` doubles as the address space and as the interaction multiplicity.
        // `timestamp_pp()` is called unconditionally, so the timestamp slot is consumed even
        // when rs2 is an immediate and the read is disabled.
        self.memory_bridge
            .read(
                MemoryAddress::new(local.rs2_as, local.rs2),
                ctx.reads[1].clone(),
                timestamp_pp(),
                &local.reads_aux[1],
            )
            .eval(builder, local.rs2_as);

        // rd write to the register address space, using the third timestamp slot.
        self.memory_bridge
            .write(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rd_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &local.writes_aux,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Execution-bus interaction with operands (a, b, c, d, e) =
        // (rd_ptr, rs1_ptr, rs2, RV32_REGISTER_AS, rs2_as). `timestamp_delta` is 3 here because
        // `timestamp_pp` was called three times above.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    local.rd_ptr.into(),
                    local.rs1_ptr.into(),
                    local.rs2.into(),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    local.rs2_as.into(),
                ],
                local.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the `from_state.pc` column of the given row.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32BaseAluAdapterCols<_> = local.borrow();
        cols.from_state.pc
    }
}
165
/// Stateless executor side of the RV32 base ALU adapter; it implements
/// `AdapterTraceExecutor` and records reads/writes into a [`Rv32BaseAluAdapterRecord`].
///
/// NOTE(review): `LIMB_BITS` is not referenced by the executor code visible in this file.
#[derive(Clone, derive_new::new)]
pub struct Rv32BaseAluAdapterExecutor<const LIMB_BITS: usize>;
168
/// Trace-filling side of the RV32 base ALU adapter; it implements `AdapterTraceFiller`.
#[derive(derive_new::new)]
pub struct Rv32BaseAluAdapterFiller<const LIMB_BITS: usize> {
    /// Shared lookup chip that receives a `request_range` for the two low limbs of rs2
    /// whenever rs2 is an immediate.
    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
173
// Intermediate type that should not be copied or cloned and should be directly written to
/// Execution record of one base ALU instruction as seen by the adapter.
///
/// It is filled by the executor (`start`, `read`, `write`) and later reinterpreted from the
/// trace row bytes by the filler, so the `#[repr(C)]` layout must stay stable.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32BaseAluAdapterRecord {
    /// Program counter when the instruction started.
    pub from_pc: u32,
    /// Memory timestamp when the instruction started.
    pub from_timestamp: u32,

    /// Destination register pointer (instruction operand `a`).
    pub rd_ptr: u32,
    /// First source register pointer (instruction operand `b`).
    pub rs1_ptr: u32,
    /// Pointer if rs2 was a read, immediate value otherwise
    pub rs2: u32,
    /// 1 if rs2 was a read, 0 if an immediate
    pub rs2_as: u8,

    /// Previous-access data for the rs1 read (index 0) and the rs2 read (index 1). The executor
    /// does not touch index 1 when rs2 is an immediate.
    pub reads_aux: [MemoryReadAuxRecord; 2],
    /// Previous timestamp and previous data of the destination register.
    pub writes_aux: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
}
191
192impl<F: PrimeField32, const LIMB_BITS: usize> AdapterTraceExecutor<F>
193    for Rv32BaseAluAdapterExecutor<LIMB_BITS>
194{
195    const WIDTH: usize = size_of::<Rv32BaseAluAdapterCols<u8>>();
196    type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2];
197    type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1];
198    type RecordMut<'a> = &'a mut Rv32BaseAluAdapterRecord;
199
200    #[inline(always)]
201    fn start(pc: u32, memory: &TracingMemory, record: &mut &mut Rv32BaseAluAdapterRecord) {
202        record.from_pc = pc;
203        record.from_timestamp = memory.timestamp;
204    }
205
206    // @dev cannot get rid of double &mut due to trait
207    #[inline(always)]
208    fn read(
209        &self,
210        memory: &mut TracingMemory,
211        instruction: &Instruction<F>,
212        record: &mut &mut Rv32BaseAluAdapterRecord,
213    ) -> Self::ReadData {
214        let &Instruction { b, c, d, e, .. } = instruction;
215
216        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
217        debug_assert!(
218            e.as_canonical_u32() == RV32_REGISTER_AS || e.as_canonical_u32() == RV32_IMM_AS
219        );
220
221        record.rs1_ptr = b.as_canonical_u32();
222        let rs1 = tracing_read(
223            memory,
224            RV32_REGISTER_AS,
225            record.rs1_ptr,
226            &mut record.reads_aux[0].prev_timestamp,
227        );
228
229        let rs2 = if e.as_canonical_u32() == RV32_REGISTER_AS {
230            record.rs2_as = RV32_REGISTER_AS as u8;
231            record.rs2 = c.as_canonical_u32();
232
233            tracing_read(
234                memory,
235                RV32_REGISTER_AS,
236                record.rs2,
237                &mut record.reads_aux[1].prev_timestamp,
238            )
239        } else {
240            record.rs2_as = RV32_IMM_AS as u8;
241
242            tracing_read_imm(memory, c.as_canonical_u32(), &mut record.rs2)
243        };
244
245        [rs1, rs2]
246    }
247
248    #[inline(always)]
249    fn write(
250        &self,
251        memory: &mut TracingMemory,
252        instruction: &Instruction<F>,
253        data: Self::WriteData,
254        record: &mut &mut Rv32BaseAluAdapterRecord,
255    ) {
256        let &Instruction { a, d, .. } = instruction;
257
258        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
259
260        record.rd_ptr = a.as_canonical_u32();
261        tracing_write(
262            memory,
263            RV32_REGISTER_AS,
264            record.rd_ptr,
265            data[0],
266            &mut record.writes_aux.prev_timestamp,
267            &mut record.writes_aux.prev_data,
268        );
269    }
270}
271
272impl<F: PrimeField32, const LIMB_BITS: usize> AdapterTraceFiller<F>
273    for Rv32BaseAluAdapterFiller<LIMB_BITS>
274{
275    const WIDTH: usize = size_of::<Rv32BaseAluAdapterCols<u8>>();
276
277    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
278        // SAFETY: the following is highly unsafe. We are going to cast `adapter_row` to a record
279        // buffer, and then do an _overlapping_ write to the `adapter_row` as a row of field
280        // elements. This requires:
281        // - Cols struct should be repr(C) and we write in reverse order (to ensure non-overlapping)
282        // - Do not overwrite any reference in `record` before it has already been used or moved
283        // - alignment of `F` must be >= alignment of Record (AlignedBytesBorrow will panic
284        //   otherwise)
285        // - adapter_row contains a valid Rv32BaseAluAdapterRecord representation
286        // - get_record_from_slice correctly interprets the bytes as Rv32BaseAluAdapterRecord
287        let record: &Rv32BaseAluAdapterRecord =
288            unsafe { get_record_from_slice(&mut adapter_row, ()) };
289        let adapter_row: &mut Rv32BaseAluAdapterCols<F> = adapter_row.borrow_mut();
290
291        // We must assign in reverse
292        const TIMESTAMP_DELTA: u32 = 2;
293        let mut timestamp = record.from_timestamp + TIMESTAMP_DELTA;
294
295        adapter_row
296            .writes_aux
297            .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8));
298        mem_helper.fill(
299            record.writes_aux.prev_timestamp,
300            timestamp,
301            adapter_row.writes_aux.as_mut(),
302        );
303        timestamp -= 1;
304
305        if record.rs2_as != 0 {
306            mem_helper.fill(
307                record.reads_aux[1].prev_timestamp,
308                timestamp,
309                adapter_row.reads_aux[1].as_mut(),
310            );
311        } else {
312            mem_helper.fill_zero(adapter_row.reads_aux[1].as_mut());
313            let rs2_imm = record.rs2;
314            let mask = (1 << RV32_CELL_BITS) - 1;
315            self.bitwise_lookup_chip
316                .request_range(rs2_imm & mask, (rs2_imm >> 8) & mask);
317        }
318        timestamp -= 1;
319
320        mem_helper.fill(
321            record.reads_aux[0].prev_timestamp,
322            timestamp,
323            adapter_row.reads_aux[0].as_mut(),
324        );
325
326        adapter_row.rs2_as = F::from_canonical_u8(record.rs2_as);
327        adapter_row.rs2 = F::from_canonical_u32(record.rs2);
328        adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr);
329        adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr);
330        adapter_row.from_state.timestamp = F::from_canonical_u32(timestamp);
331        adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc);
332    }
333}