// openvm_rv32im_circuit/adapters/alu.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    marker::PhantomData,
4};
5
6use openvm_circuit::{
7    arch::{
8        AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
9        ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip,
10        VmAdapterInterface,
11    },
12    system::{
13        memory::{
14            offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
15            MemoryAddress, MemoryController, OfflineMemory, RecordId,
16        },
17        program::ProgramBus,
18    },
19};
20use openvm_circuit_primitives::{
21    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
22    utils::not,
23};
24use openvm_circuit_primitives_derive::AlignedBorrow;
25use openvm_instructions::{
26    instruction::Instruction,
27    program::DEFAULT_PC_STEP,
28    riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
29};
30use openvm_stark_backend::{
31    interaction::InteractionBuilder,
32    p3_air::{AirBuilder, BaseAir},
33    p3_field::{Field, FieldAlgebra, PrimeField32},
34};
35use serde::{Deserialize, Serialize};
36
37use super::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
38
39/// Reads instructions of the form OP a, b, c, d, e where \[a:4\]_d = \[b:4\]_d op \[c:4\]_e.
40/// Operand d can only be 1, and e can be either 1 (for register reads) or 0 (when c
41/// is an immediate).
42pub struct Rv32BaseAluAdapterChip<F: Field> {
43    pub air: Rv32BaseAluAdapterAir,
44    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
45    _marker: PhantomData<F>,
46}
47
48impl<F: PrimeField32> Rv32BaseAluAdapterChip<F> {
49    pub fn new(
50        execution_bus: ExecutionBus,
51        program_bus: ProgramBus,
52        memory_bridge: MemoryBridge,
53        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
54    ) -> Self {
55        Self {
56            air: Rv32BaseAluAdapterAir {
57                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
58                memory_bridge,
59                bitwise_lookup_bus: bitwise_lookup_chip.bus(),
60            },
61            bitwise_lookup_chip,
62            _marker: PhantomData,
63        }
64    }
65}
66
/// Read-phase record produced by `preprocess` and consumed by `generate_trace_row`.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32BaseAluReadRecord<F: Field> {
    /// Record id of the `rs1` register read (address space `d` = 1).
    pub rs1: RecordId,
    /// Record id of the `rs2` register read, or `None` when `c` is an immediate
    /// (`e` = 0) and no memory read was performed.
    pub rs2: Option<RecordId>,
    /// The immediate operand `c` when `rs2` is `None`; `F::ZERO` when `rs2` was read
    /// from a register.
    pub rs2_imm: F,
}
80
81#[repr(C)]
82#[derive(Clone, Debug, Serialize, Deserialize)]
83#[serde(bound = "F: Field")]
84pub struct Rv32BaseAluWriteRecord<F: Field> {
85    pub from_state: ExecutionState<u32>,
86    /// Write to destination register
87    pub rd: (RecordId, [F; 4]),
88}
89
/// Trace columns of the base ALU adapter, filled by `generate_trace_row` and constrained by
/// `Rv32BaseAluAdapterAir::eval`.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32BaseAluAdapterCols<T> {
    /// Program counter and timestamp at which the instruction starts.
    pub from_state: ExecutionState<T>,
    /// Pointer to the destination register `rd` (instruction operand `a`).
    pub rd_ptr: T,
    /// Pointer to the first source register `rs1` (instruction operand `b`).
    pub rs1_ptr: T,
    /// Pointer if rs2 was a read, immediate value otherwise (instruction operand `c`).
    pub rs2: T,
    /// 1 if rs2 was a read, 0 if an immediate (instruction operand `e`; constrained boolean).
    pub rs2_as: T,
    /// Memory-read auxiliary columns for `rs1` (`[0]`) and `rs2` (`[1]`); `[1]` is not filled
    /// when `rs2` is an immediate.
    pub reads_aux: [MemoryReadAuxCols<T>; 2],
    /// Memory-write auxiliary columns for the `rd` write.
    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
103
/// AIR for the base ALU adapter: constrains the `rs1` read, the `rs2` read or immediate
/// decomposition, the `rd` write, and the execution-bus transition.
// NOTE(review): `bitwise_lookup_bus` is read in `eval`, so it is unclear which item this
// `dead_code` allowance covers — confirm whether it is still needed.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32BaseAluAdapterAir {
    pub(super) execution_bridge: ExecutionBridge,
    pub(super) memory_bridge: MemoryBridge,
    /// Bus used to range-check the two low limbs of an immediate `rs2`.
    bitwise_lookup_bus: BitwiseOperationLookupBus,
}
111
112impl<F: Field> BaseAir<F> for Rv32BaseAluAdapterAir {
113    fn width(&self) -> usize {
114        Rv32BaseAluAdapterCols::<F>::width()
115    }
116}
117
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32BaseAluAdapterAir {
    // Two read operands (`ctx.reads[0]`, `ctx.reads[1]`) and one write operand
    // (`ctx.writes[0]`), each `RV32_REGISTER_NUM_LIMBS` limbs wide.
    // NOTE(review): generic order assumed to be <expr, instr, reads, writes, read size,
    // write size> — confirm against `BasicAdapterInterface`.
    type Interface = BasicAdapterInterface<
        AB::Expr,
        MinimalInstruction<AB::Expr>,
        2,
        1,
        RV32_REGISTER_NUM_LIMBS,
        RV32_REGISTER_NUM_LIMBS,
    >;

    /// Constrains one adapter row: the `rs1` read, either the `rs2` read or the immediate
    /// decomposition, the `rd` write, and the execution-bus interaction.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local: &Rv32BaseAluAdapterCols<_> = local.borrow();
        let timestamp = local.from_state.timestamp;
        // `timestamp_pp()` yields `timestamp + k` for the k-th call (k = 0, 1, 2, ...);
        // each memory access below takes one slot, and the final count becomes the
        // timestamp delta passed to the execution bridge.
        let mut timestamp_delta: usize = 0;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        // If rs2 is an immediate value (rs2_as == 0), constrain that:
        // 1. `local.rs2 == limbs[0] + limbs[1] * 2^8 + sign * 2^16`, where limbs[0] and
        //    limbs[1] are range-checked to 8 bits by the `send_range` below, so the low
        //    16 bits of the immediate live in rs2_limbs[0] and rs2_limbs[1].
        // 2. It is sign-extended to 32 bits: the sign limb rs2_limbs[2] is 0 or 2^8 - 1, and
        //    rs2_limbs[3] equals it.
        // No constraint here ties the sign limb to bit 15 of rs2_limbs[1].
        // NOTE(review): presumably the instruction encoding guarantees that — confirm.
        let rs2_limbs = ctx.reads[1].clone();
        let rs2_sign = rs2_limbs[2].clone();
        let rs2_imm = rs2_limbs[0].clone()
            + rs2_limbs[1].clone() * AB::Expr::from_canonical_usize(1 << RV32_CELL_BITS)
            + rs2_sign.clone() * AB::Expr::from_canonical_usize(1 << (2 * RV32_CELL_BITS));
        builder.assert_bool(local.rs2_as);
        let mut rs2_imm_when = builder.when(not(local.rs2_as));
        rs2_imm_when.assert_eq(local.rs2, rs2_imm);
        rs2_imm_when.assert_eq(rs2_sign.clone(), rs2_limbs[3].clone());
        // sign * (255 - sign) == 0  <=>  sign ∈ {0, 255}
        rs2_imm_when.assert_zero(
            rs2_sign.clone()
                * (AB::Expr::from_canonical_usize((1 << RV32_CELL_BITS) - 1) - rs2_sign),
        );
        // Enabled with multiplicity `is_valid - rs2_as`, i.e. only on valid immediate rows
        // (rs2_as == 1 forces is_valid == 1 below, making the multiplicity 0).
        self.bitwise_lookup_bus
            .send_range(rs2_limbs[0].clone(), rs2_limbs[1].clone())
            .eval(builder, ctx.instruction.is_valid.clone() - local.rs2_as);

        // rs1 is always read from the register address space.
        self.memory_bridge
            .read(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rs1_ptr),
                ctx.reads[0].clone(),
                timestamp_pp(),
                &local.reads_aux[0],
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // This constraint ensures that the following memory read only occurs when `is_valid == 1`.
        builder
            .when(local.rs2_as)
            .assert_one(ctx.instruction.is_valid.clone());
        // The rs2 read is gated by rs2_as itself; its timestamp slot is consumed either way,
        // so the timestamp delta is the same for register and immediate rows.
        self.memory_bridge
            .read(
                MemoryAddress::new(local.rs2_as, local.rs2),
                ctx.reads[1].clone(),
                timestamp_pp(),
                &local.reads_aux[1],
            )
            .eval(builder, local.rs2_as);

        self.memory_bridge
            .write(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rd_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &local.writes_aux,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Instruction operands [a, b, c, d, e] = [rd_ptr, rs1_ptr, rs2, register AS, rs2_as].
        // NOTE(review): presumably the pc advances by DEFAULT_PC_STEP unless the core supplies
        // `ctx.to_pc` — confirm against `ExecutionBridge::execute_and_increment_or_set_pc`.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    local.rd_ptr.into(),
                    local.rs1_ptr.into(),
                    local.rs2.into(),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    local.rs2_as.into(),
                ],
                local.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the `from_state.pc` column of this row.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32BaseAluAdapterCols<_> = local.borrow();
        cols.from_state.pc
    }
}
215
216impl<F: PrimeField32> VmAdapterChip<F> for Rv32BaseAluAdapterChip<F> {
217    type ReadRecord = Rv32BaseAluReadRecord<F>;
218    type WriteRecord = Rv32BaseAluWriteRecord<F>;
219    type Air = Rv32BaseAluAdapterAir;
220    type Interface = BasicAdapterInterface<
221        F,
222        MinimalInstruction<F>,
223        2,
224        1,
225        RV32_REGISTER_NUM_LIMBS,
226        RV32_REGISTER_NUM_LIMBS,
227    >;
228
229    fn preprocess(
230        &mut self,
231        memory: &mut MemoryController<F>,
232        instruction: &Instruction<F>,
233    ) -> Result<(
234        <Self::Interface as VmAdapterInterface<F>>::Reads,
235        Self::ReadRecord,
236    )> {
237        let Instruction { b, c, d, e, .. } = *instruction;
238
239        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
240        debug_assert!(
241            e.as_canonical_u32() == RV32_IMM_AS || e.as_canonical_u32() == RV32_REGISTER_AS
242        );
243
244        let rs1 = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, b);
245        let (rs2, rs2_data, rs2_imm) = if e.is_zero() {
246            let c_u32 = c.as_canonical_u32();
247            debug_assert_eq!(c_u32 >> 24, 0);
248            memory.increment_timestamp();
249            (
250                None,
251                [
252                    c_u32 as u8,
253                    (c_u32 >> 8) as u8,
254                    (c_u32 >> 16) as u8,
255                    (c_u32 >> 16) as u8,
256                ]
257                .map(F::from_canonical_u8),
258                c,
259            )
260        } else {
261            let rs2_read = memory.read::<RV32_REGISTER_NUM_LIMBS>(e, c);
262            (Some(rs2_read.0), rs2_read.1, F::ZERO)
263        };
264
265        Ok((
266            [rs1.1, rs2_data],
267            Self::ReadRecord {
268                rs1: rs1.0,
269                rs2,
270                rs2_imm,
271            },
272        ))
273    }
274
275    fn postprocess(
276        &mut self,
277        memory: &mut MemoryController<F>,
278        instruction: &Instruction<F>,
279        from_state: ExecutionState<u32>,
280        output: AdapterRuntimeContext<F, Self::Interface>,
281        _read_record: &Self::ReadRecord,
282    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
283        let Instruction { a, d, .. } = instruction;
284        let rd = memory.write(*d, *a, output.writes[0]);
285
286        let timestamp_delta = memory.timestamp() - from_state.timestamp;
287        debug_assert!(
288            timestamp_delta == 3,
289            "timestamp delta is {}, expected 3",
290            timestamp_delta
291        );
292
293        Ok((
294            ExecutionState {
295                pc: from_state.pc + DEFAULT_PC_STEP,
296                timestamp: memory.timestamp(),
297            },
298            Self::WriteRecord { from_state, rd },
299        ))
300    }
301
302    fn generate_trace_row(
303        &self,
304        row_slice: &mut [F],
305        read_record: Self::ReadRecord,
306        write_record: Self::WriteRecord,
307        memory: &OfflineMemory<F>,
308    ) {
309        let row_slice: &mut Rv32BaseAluAdapterCols<_> = row_slice.borrow_mut();
310        let aux_cols_factory = memory.aux_cols_factory();
311
312        let rd = memory.record_by_id(write_record.rd.0);
313        row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);
314        row_slice.rd_ptr = rd.pointer;
315
316        let rs1 = memory.record_by_id(read_record.rs1);
317        let rs2 = read_record.rs2.map(|rs2| memory.record_by_id(rs2));
318        row_slice.rs1_ptr = rs1.pointer;
319
320        if let Some(rs2) = rs2 {
321            row_slice.rs2 = rs2.pointer;
322            row_slice.rs2_as = rs2.address_space;
323            aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]);
324            aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]);
325        } else {
326            row_slice.rs2 = read_record.rs2_imm;
327            row_slice.rs2_as = F::ZERO;
328            let rs2_imm = row_slice.rs2.as_canonical_u32();
329            let mask = (1 << RV32_CELL_BITS) - 1;
330            self.bitwise_lookup_chip
331                .request_range(rs2_imm & mask, (rs2_imm >> 8) & mask);
332            aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]);
333            // row_slice.reads_aux[1] is disabled
334        }
335        aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux);
336    }
337
338    fn air(&self) -> &Self::Air {
339        &self.air
340    }
341}