openvm_rv32im_circuit/adapters/rdwrite.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    marker::PhantomData,
4};
5
6use openvm_circuit::{
7    arch::{
8        AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
9        ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip,
10        VmAdapterInterface,
11    },
12    system::{
13        memory::{
14            offline_checker::{MemoryBridge, MemoryWriteAuxCols},
15            MemoryAddress, MemoryController, OfflineMemory, RecordId,
16        },
17        program::ProgramBus,
18    },
19};
20use openvm_circuit_primitives::utils::not;
21use openvm_circuit_primitives_derive::AlignedBorrow;
22use openvm_instructions::{
23    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
24};
25use openvm_stark_backend::{
26    interaction::InteractionBuilder,
27    p3_air::{AirBuilder, BaseAir},
28    p3_field::{Field, FieldAlgebra, PrimeField32},
29};
30use serde::{Deserialize, Serialize};
31
32use super::RV32_REGISTER_NUM_LIMBS;
33
34/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1
35#[derive(Debug)]
36pub struct Rv32RdWriteAdapterChip<F: Field> {
37    pub air: Rv32RdWriteAdapterAir,
38    _marker: PhantomData<F>,
39}
40
41/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1
42#[derive(Debug)]
43pub struct Rv32CondRdWriteAdapterChip<F: Field> {
44    /// Do not use the inner air directly, use `air` instead.
45    inner: Rv32RdWriteAdapterChip<F>,
46    pub air: Rv32CondRdWriteAdapterAir,
47}
48
49impl<F: PrimeField32> Rv32RdWriteAdapterChip<F> {
50    pub fn new(
51        execution_bus: ExecutionBus,
52        program_bus: ProgramBus,
53        memory_bridge: MemoryBridge,
54    ) -> Self {
55        Self {
56            air: Rv32RdWriteAdapterAir {
57                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
58                memory_bridge,
59            },
60            _marker: PhantomData,
61        }
62    }
63}
64
65impl<F: PrimeField32> Rv32CondRdWriteAdapterChip<F> {
66    pub fn new(
67        execution_bus: ExecutionBus,
68        program_bus: ProgramBus,
69        memory_bridge: MemoryBridge,
70    ) -> Self {
71        let inner = Rv32RdWriteAdapterChip::new(execution_bus, program_bus, memory_bridge);
72        let air = Rv32CondRdWriteAdapterAir { inner: inner.air };
73        Self { inner, air }
74    }
75}
76
77#[repr(C)]
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct Rv32RdWriteWriteRecord {
80    pub from_state: ExecutionState<u32>,
81    pub rd_id: Option<RecordId>,
82}
83
84#[repr(C)]
85#[derive(Debug, Clone, AlignedBorrow)]
86pub struct Rv32RdWriteAdapterCols<T> {
87    pub from_state: ExecutionState<T>,
88    pub rd_ptr: T,
89    pub rd_aux_cols: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
90}
91
92#[repr(C)]
93#[derive(Debug, Clone, AlignedBorrow)]
94pub struct Rv32CondRdWriteAdapterCols<T> {
95    inner: Rv32RdWriteAdapterCols<T>,
96    pub needs_write: T,
97}
98
99#[derive(Clone, Copy, Debug, derive_new::new)]
100pub struct Rv32RdWriteAdapterAir {
101    pub(super) memory_bridge: MemoryBridge,
102    pub(super) execution_bridge: ExecutionBridge,
103}
104
105#[derive(Clone, Copy, Debug, derive_new::new)]
106pub struct Rv32CondRdWriteAdapterAir {
107    inner: Rv32RdWriteAdapterAir,
108}
109
110impl<F: Field> BaseAir<F> for Rv32RdWriteAdapterAir {
111    fn width(&self) -> usize {
112        Rv32RdWriteAdapterCols::<F>::width()
113    }
114}
115
116impl<F: Field> BaseAir<F> for Rv32CondRdWriteAdapterAir {
117    fn width(&self) -> usize {
118        Rv32CondRdWriteAdapterCols::<F>::width()
119    }
120}
121
122impl Rv32RdWriteAdapterAir {
123    /// If `needs_write` is provided:
124    /// - Only writes if `needs_write`.
125    /// - Sets operand `f = needs_write` in the instruction.
126    /// - Does not put any other constraints on `needs_write`
127    ///
128    /// Otherwise:
129    /// - Writes if `ctx.instruction.is_valid`.
130    /// - Sets operand `f` to default value of `0` in the instruction.
131    #[allow(clippy::type_complexity)]
132    fn conditional_eval<AB: InteractionBuilder>(
133        &self,
134        builder: &mut AB,
135        local_cols: &Rv32RdWriteAdapterCols<AB::Var>,
136        ctx: AdapterAirContext<
137            AB::Expr,
138            BasicAdapterInterface<
139                AB::Expr,
140                ImmInstruction<AB::Expr>,
141                0,
142                1,
143                0,
144                RV32_REGISTER_NUM_LIMBS,
145            >,
146        >,
147        needs_write: Option<AB::Expr>,
148    ) {
149        let timestamp: AB::Var = local_cols.from_state.timestamp;
150        let timestamp_delta = 1;
151        let (write_count, f) = if let Some(needs_write) = needs_write {
152            (needs_write.clone(), needs_write)
153        } else {
154            (ctx.instruction.is_valid.clone(), AB::Expr::ZERO)
155        };
156        self.memory_bridge
157            .write(
158                MemoryAddress::new(
159                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
160                    local_cols.rd_ptr,
161                ),
162                ctx.writes[0].clone(),
163                timestamp,
164                &local_cols.rd_aux_cols,
165            )
166            .eval(builder, write_count);
167
168        let to_pc = ctx
169            .to_pc
170            .unwrap_or(local_cols.from_state.pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP));
171        // regardless of `needs_write`, must always execute instruction when `is_valid`.
172        self.execution_bridge
173            .execute(
174                ctx.instruction.opcode,
175                [
176                    local_cols.rd_ptr.into(),
177                    AB::Expr::ZERO,
178                    ctx.instruction.immediate,
179                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
180                    AB::Expr::ZERO,
181                    f,
182                ],
183                local_cols.from_state,
184                ExecutionState {
185                    pc: to_pc,
186                    timestamp: timestamp + AB::F::from_canonical_usize(timestamp_delta),
187                },
188            )
189            .eval(builder, ctx.instruction.is_valid);
190    }
191}
192
193impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32RdWriteAdapterAir {
194    type Interface =
195        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
196
197    fn eval(
198        &self,
199        builder: &mut AB,
200        local: &[AB::Var],
201        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
202    ) {
203        let local_cols: &Rv32RdWriteAdapterCols<AB::Var> = (*local).borrow();
204        self.conditional_eval(builder, local_cols, ctx, None);
205    }
206
207    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
208        let cols: &Rv32RdWriteAdapterCols<_> = local.borrow();
209        cols.from_state.pc
210    }
211}
212
213impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32CondRdWriteAdapterAir {
214    type Interface =
215        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
216
217    fn eval(
218        &self,
219        builder: &mut AB,
220        local: &[AB::Var],
221        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
222    ) {
223        let local_cols: &Rv32CondRdWriteAdapterCols<AB::Var> = (*local).borrow();
224
225        builder.assert_bool(local_cols.needs_write);
226        builder
227            .when::<AB::Expr>(not(ctx.instruction.is_valid.clone()))
228            .assert_zero(local_cols.needs_write);
229
230        self.inner.conditional_eval(
231            builder,
232            &local_cols.inner,
233            ctx,
234            Some(local_cols.needs_write.into()),
235        );
236    }
237
238    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
239        let cols: &Rv32CondRdWriteAdapterCols<_> = local.borrow();
240        cols.inner.from_state.pc
241    }
242}
243
244impl<F: PrimeField32> VmAdapterChip<F> for Rv32RdWriteAdapterChip<F> {
245    type ReadRecord = ();
246    type WriteRecord = Rv32RdWriteWriteRecord;
247    type Air = Rv32RdWriteAdapterAir;
248    type Interface = BasicAdapterInterface<F, ImmInstruction<F>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
249
250    fn preprocess(
251        &mut self,
252        _memory: &mut MemoryController<F>,
253        instruction: &Instruction<F>,
254    ) -> Result<(
255        <Self::Interface as VmAdapterInterface<F>>::Reads,
256        Self::ReadRecord,
257    )> {
258        let d = instruction.d;
259        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
260
261        Ok(([], ()))
262    }
263
264    fn postprocess(
265        &mut self,
266        memory: &mut MemoryController<F>,
267        instruction: &Instruction<F>,
268        from_state: ExecutionState<u32>,
269        output: AdapterRuntimeContext<F, Self::Interface>,
270        _read_record: &Self::ReadRecord,
271    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
272        let Instruction { a, d, .. } = *instruction;
273        let (rd_id, _) = memory.write(d, a, output.writes[0]);
274
275        Ok((
276            ExecutionState {
277                pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP),
278                timestamp: memory.timestamp(),
279            },
280            Self::WriteRecord {
281                from_state,
282                rd_id: Some(rd_id),
283            },
284        ))
285    }
286
287    fn generate_trace_row(
288        &self,
289        row_slice: &mut [F],
290        _read_record: Self::ReadRecord,
291        write_record: Self::WriteRecord,
292        memory: &OfflineMemory<F>,
293    ) {
294        let aux_cols_factory = memory.aux_cols_factory();
295        let adapter_cols: &mut Rv32RdWriteAdapterCols<F> = row_slice.borrow_mut();
296        adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32);
297        let rd = memory.record_by_id(write_record.rd_id.unwrap());
298        adapter_cols.rd_ptr = rd.pointer;
299        aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.rd_aux_cols);
300    }
301
302    fn air(&self) -> &Self::Air {
303        &self.air
304    }
305}
306
307impl<F: PrimeField32> VmAdapterChip<F> for Rv32CondRdWriteAdapterChip<F> {
308    type ReadRecord = ();
309    type WriteRecord = Rv32RdWriteWriteRecord;
310    type Air = Rv32CondRdWriteAdapterAir;
311    type Interface = BasicAdapterInterface<F, ImmInstruction<F>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
312
313    fn preprocess(
314        &mut self,
315        memory: &mut MemoryController<F>,
316        instruction: &Instruction<F>,
317    ) -> Result<(
318        <Self::Interface as VmAdapterInterface<F>>::Reads,
319        Self::ReadRecord,
320    )> {
321        self.inner.preprocess(memory, instruction)
322    }
323
324    fn postprocess(
325        &mut self,
326        memory: &mut MemoryController<F>,
327        instruction: &Instruction<F>,
328        from_state: ExecutionState<u32>,
329        output: AdapterRuntimeContext<F, Self::Interface>,
330        _read_record: &Self::ReadRecord,
331    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
332        let Instruction { a, d, .. } = *instruction;
333        let rd_id = if instruction.f != F::ZERO {
334            let (rd_id, _) = memory.write(d, a, output.writes[0]);
335            Some(rd_id)
336        } else {
337            memory.increment_timestamp();
338            None
339        };
340
341        Ok((
342            ExecutionState {
343                pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP),
344                timestamp: memory.timestamp(),
345            },
346            Self::WriteRecord { from_state, rd_id },
347        ))
348    }
349
350    fn generate_trace_row(
351        &self,
352        row_slice: &mut [F],
353        _read_record: Self::ReadRecord,
354        write_record: Self::WriteRecord,
355        memory: &OfflineMemory<F>,
356    ) {
357        let aux_cols_factory = memory.aux_cols_factory();
358        let adapter_cols: &mut Rv32CondRdWriteAdapterCols<F> = row_slice.borrow_mut();
359        adapter_cols.inner.from_state = write_record.from_state.map(F::from_canonical_u32);
360        if let Some(rd_id) = write_record.rd_id {
361            let rd = memory.record_by_id(rd_id);
362            adapter_cols.inner.rd_ptr = rd.pointer;
363            aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.inner.rd_aux_cols);
364            adapter_cols.needs_write = F::ONE;
365        }
366    }
367
368    fn air(&self) -> &Self::Air {
369        &self.air
370    }
371}