openvm_rv32im_circuit/adapters/rdwrite.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4    arch::{
5        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
6        BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir,
7    },
8    system::memory::{
9        offline_checker::{MemoryBridge, MemoryWriteAuxCols, MemoryWriteBytesAuxRecord},
10        online::TracingMemory,
11        MemoryAddress, MemoryAuxColsFactory,
12    },
13};
14use openvm_circuit_primitives::{utils::not, AlignedBytesBorrow};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{
17    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
18};
19use openvm_stark_backend::{
20    interaction::InteractionBuilder,
21    p3_air::{AirBuilder, BaseAir},
22    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
23};
24
25use super::RV32_REGISTER_NUM_LIMBS;
26use crate::adapters::tracing_write;
27
/// Trace columns of [`Rv32RdWriteAdapterAir`].
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32RdWriteAdapterCols<T> {
    /// PC and timestamp the instruction starts from (the `from_state` of the execution bridge).
    pub from_state: ExecutionState<T>,
    /// Pointer of the destination register; used as the write address and as operand `a`.
    pub rd_ptr: T,
    /// Auxiliary columns for the write to `rd` in the register address space.
    pub rd_aux_cols: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
35
/// Trace columns of [`Rv32CondRdWriteAdapterAir`]: the unconditional adapter's columns followed
/// by a flag that selects whether the register write actually happens.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32CondRdWriteAdapterCols<T> {
    /// Columns shared with [`Rv32RdWriteAdapterCols`]; they must stay first in the layout.
    pub inner: Rv32RdWriteAdapterCols<T>,
    /// Boolean write-enable flag. It is forced to 0 when the instruction is not valid, and it is
    /// exposed as operand `f` of the executed instruction.
    pub needs_write: T,
}
42
/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1
///
/// The write is enabled whenever the row's instruction is valid. Operand `f` of the instruction
/// is fixed to 0.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32RdWriteAdapterAir {
    /// Bridge used to constrain the write to the destination register.
    pub(super) memory_bridge: MemoryBridge,
    /// Bridge used to constrain the instruction's pc/timestamp transition.
    pub(super) execution_bridge: ExecutionBridge,
}
49
/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1
///
/// On top of [`Rv32RdWriteAdapterAir`] it adds a boolean `needs_write` column. That column gates
/// the register write and is exposed as operand `f` of the instruction.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32CondRdWriteAdapterAir {
    /// Unconditional adapter whose constraint logic is reused, gated by `needs_write`.
    inner: Rv32RdWriteAdapterAir,
}
55
56impl<F: Field> BaseAir<F> for Rv32RdWriteAdapterAir {
57    fn width(&self) -> usize {
58        Rv32RdWriteAdapterCols::<F>::width()
59    }
60}
61
62impl<F: Field> BaseAir<F> for Rv32CondRdWriteAdapterAir {
63    fn width(&self) -> usize {
64        Rv32CondRdWriteAdapterCols::<F>::width()
65    }
66}
67
impl Rv32RdWriteAdapterAir {
    /// Shared constraint logic of [`Rv32RdWriteAdapterAir`] and [`Rv32CondRdWriteAdapterAir`].
    ///
    /// Emits, in this order:
    /// 1. a memory-bridge write of `ctx.writes[0]` to register `rd_ptr` (address space
    ///    `RV32_REGISTER_AS`) at the row's starting timestamp;
    /// 2. an execution-bridge interaction for the instruction
    ///    `[rd_ptr, 0, imm, RV32_REGISTER_AS, 0, f]`, moving from `from_state` to
    ///    `(to_pc, timestamp + 1)`. Here `to_pc` is `ctx.to_pc` if set, else
    ///    `pc + DEFAULT_PC_STEP`.
    ///
    /// If `needs_write` is provided:
    /// - Only writes if `needs_write`.
    /// - Sets operand `f = needs_write` in the instruction.
    /// - Does not put any other constraints on `needs_write`
    ///
    /// Otherwise:
    /// - Writes if `ctx.instruction.is_valid`.
    /// - Sets operand `f` to default value of `0` in the instruction.
    // The spelled-out `AdapterAirContext<..., BasicAdapterInterface<...>>` parameter type trips
    // clippy's `type_complexity` lint. It is the same interface both adapters declare.
    #[allow(clippy::type_complexity)]
    fn conditional_eval<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local_cols: &Rv32RdWriteAdapterCols<AB::Var>,
        ctx: AdapterAirContext<
            AB::Expr,
            BasicAdapterInterface<
                AB::Expr,
                ImmInstruction<AB::Expr>,
                0,
                1,
                0,
                RV32_REGISTER_NUM_LIMBS,
            >,
        >,
        needs_write: Option<AB::Expr>,
    ) {
        let timestamp: AB::Var = local_cols.from_state.timestamp;
        // The `rd` write is this adapter's only memory access and happens at `timestamp`;
        // the instruction ends one step later.
        let timestamp_delta = 1;
        // `write_count` gates the register write; `f` is the value placed in operand `f`.
        let (write_count, f) = if let Some(needs_write) = needs_write {
            (needs_write.clone(), needs_write)
        } else {
            (ctx.instruction.is_valid.clone(), AB::Expr::ZERO)
        };
        self.memory_bridge
            .write(
                MemoryAddress::new(AB::F::from_u32(RV32_REGISTER_AS), local_cols.rd_ptr),
                ctx.writes[0].clone(),
                timestamp,
                &local_cols.rd_aux_cols,
            )
            .eval(builder, write_count);

        // `ctx.to_pc` overrides the next pc when set; otherwise fall through by DEFAULT_PC_STEP.
        let to_pc = ctx
            .to_pc
            .unwrap_or(local_cols.from_state.pc + AB::F::from_u32(DEFAULT_PC_STEP));
        // regardless of `needs_write`, must always execute instruction when `is_valid`.
        self.execution_bridge
            .execute(
                ctx.instruction.opcode,
                // Instruction operands [a, b, c, d, e, f].
                [
                    local_cols.rd_ptr.into(),
                    AB::Expr::ZERO,
                    ctx.instruction.immediate,
                    AB::Expr::from_u32(RV32_REGISTER_AS),
                    AB::Expr::ZERO,
                    f,
                ],
                local_cols.from_state,
                ExecutionState {
                    pc: to_pc,
                    timestamp: timestamp + AB::F::from_usize(timestamp_delta),
                },
            )
            .eval(builder, ctx.instruction.is_valid);
    }
}
135
136impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32RdWriteAdapterAir {
137    type Interface =
138        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
139
140    fn eval(
141        &self,
142        builder: &mut AB,
143        local: &[AB::Var],
144        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
145    ) {
146        let local_cols: &Rv32RdWriteAdapterCols<AB::Var> = (*local).borrow();
147        self.conditional_eval(builder, local_cols, ctx, None);
148    }
149
150    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
151        let cols: &Rv32RdWriteAdapterCols<_> = local.borrow();
152        cols.from_state.pc
153    }
154}
155
156impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32CondRdWriteAdapterAir {
157    type Interface =
158        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
159
160    fn eval(
161        &self,
162        builder: &mut AB,
163        local: &[AB::Var],
164        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
165    ) {
166        let local_cols: &Rv32CondRdWriteAdapterCols<AB::Var> = (*local).borrow();
167
168        builder.assert_bool(local_cols.needs_write);
169        builder
170            .when::<AB::Expr>(not(ctx.instruction.is_valid.clone()))
171            .assert_zero(local_cols.needs_write);
172
173        self.inner.conditional_eval(
174            builder,
175            &local_cols.inner,
176            ctx,
177            Some(local_cols.needs_write.into()),
178        );
179    }
180
181    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
182        let cols: &Rv32CondRdWriteAdapterCols<_> = local.borrow();
183        cols.inner.from_state.pc
184    }
185}
186
/// Execution record shared by [`Rv32RdWriteAdapterExecutor`] and
/// [`Rv32CondRdWriteAdapterExecutor`]. [`Rv32RdWriteAdapterFiller`] and
/// [`Rv32CondRdWriteAdapterFiller`] later read it back from the start of the trace row.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug, Clone)]
pub struct Rv32RdWriteAdapterRecord {
    /// PC at which the instruction starts.
    pub from_pc: u32,
    /// Memory timestamp at which the instruction starts.
    pub from_timestamp: u32,

    /// Pointer of the destination register (operand `a`). The conditional executor sets it to
    /// `u32::MAX` to indicate that no write happened; canonical values of a 32-bit prime field
    /// are always below `u32::MAX`, so the sentinel cannot collide with a real pointer.
    pub rd_ptr: u32,
    /// Previous timestamp and previous bytes of `rd`, captured by `tracing_write`.
    /// Not populated when the write was skipped.
    pub rd_aux_record: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
}
198
/// Trace executor for [`Rv32RdWriteAdapterAir`]. It records the starting pc/timestamp and
/// performs the register write to `rd`, filling an [`Rv32RdWriteAdapterRecord`].
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32RdWriteAdapterExecutor;
201
/// Trace filler for [`Rv32RdWriteAdapterAir`]. It turns an [`Rv32RdWriteAdapterRecord`] stored
/// in a trace row into [`Rv32RdWriteAdapterCols`].
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32RdWriteAdapterFiller;
204
205impl<F> AdapterTraceExecutor<F> for Rv32RdWriteAdapterExecutor
206where
207    F: PrimeField32,
208{
209    const WIDTH: usize = size_of::<Rv32RdWriteAdapterCols<u8>>();
210    type ReadData = ();
211    type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
212    type RecordMut<'a> = &'a mut Rv32RdWriteAdapterRecord;
213
214    #[inline(always)]
215    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
216        record.from_pc = pc;
217        record.from_timestamp = memory.timestamp;
218    }
219
220    #[inline(always)]
221    fn read(
222        &self,
223        _memory: &mut TracingMemory,
224        _instruction: &Instruction<F>,
225        _record: &mut Self::RecordMut<'_>,
226    ) -> Self::ReadData {
227        // Rv32RdWriteAdapter doesn't read anything
228    }
229
230    #[inline(always)]
231    fn write(
232        &self,
233        memory: &mut TracingMemory,
234        instruction: &Instruction<F>,
235        data: Self::WriteData,
236        record: &mut Self::RecordMut<'_>,
237    ) {
238        let &Instruction { a, d, .. } = instruction;
239
240        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
241
242        record.rd_ptr = a.as_canonical_u32();
243        tracing_write(
244            memory,
245            RV32_REGISTER_AS,
246            record.rd_ptr,
247            data,
248            &mut record.rd_aux_record.prev_timestamp,
249            &mut record.rd_aux_record.prev_data,
250        );
251    }
252}
253
impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32RdWriteAdapterFiller {
    const WIDTH: usize = size_of::<Rv32RdWriteAdapterCols<u8>>();

    /// Converts the [`Rv32RdWriteAdapterRecord`] stored at the start of `adapter_row` into the
    /// field-element columns of [`Rv32RdWriteAdapterCols`], in place.
    ///
    /// NOTE(review): `record` and the mutable `Rv32RdWriteAdapterCols` view are both derived
    /// from `adapter_row`, so they alias the same buffer. Each column write may overwrite bytes
    /// of the record. The statements fill `rd_aux_cols` first and `from_state` last, i.e. in
    /// reverse field order. Presumably this ensures every record field is read before its bytes
    /// are clobbered. Do not reorder without re-checking both layouts.
    #[inline(always)]
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY:
        // - caller ensures `adapter_row` contains a valid record representation that was previously
        //   written by the executor
        // - get_record_from_slice correctly interprets the bytes as Rv32RdWriteAdapterRecord
        let record: &Rv32RdWriteAdapterRecord =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let adapter_row: &mut Rv32RdWriteAdapterCols<F> = adapter_row.borrow_mut();

        // Aux columns first: the previous bytes of `rd`, then the timestamp-related aux
        // (presumably the prev-timestamp / less-than decomposition filled by `mem_helper`).
        adapter_row
            .rd_aux_cols
            .set_prev_data(record.rd_aux_record.prev_data.map(F::from_u8));
        mem_helper.fill(
            record.rd_aux_record.prev_timestamp,
            record.from_timestamp,
            adapter_row.rd_aux_cols.as_mut(),
        );
        // Then the leading columns, each read from the record just before it is overwritten.
        adapter_row.rd_ptr = F::from_u32(record.rd_ptr);
        adapter_row.from_state.timestamp = F::from_u32(record.from_timestamp);
        adapter_row.from_state.pc = F::from_u32(record.from_pc);
    }
}
280
/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1
///
/// The write is performed only when operand `f` of the instruction is one. The write itself is
/// delegated to the wrapped [`Rv32RdWriteAdapterExecutor`].
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32CondRdWriteAdapterExecutor {
    /// Unconditional executor used when the write is enabled.
    inner: Rv32RdWriteAdapterExecutor,
}
286
/// Trace filler for [`Rv32CondRdWriteAdapterAir`]. When a write happened, it delegates to the
/// wrapped [`Rv32RdWriteAdapterFiller`].
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32CondRdWriteAdapterFiller {
    /// Filler for the leading unconditional columns.
    inner: Rv32RdWriteAdapterFiller,
}
291
292impl<F> AdapterTraceExecutor<F> for Rv32CondRdWriteAdapterExecutor
293where
294    F: PrimeField32,
295{
296    const WIDTH: usize = size_of::<Rv32CondRdWriteAdapterCols<u8>>();
297    type ReadData = ();
298    type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
299    type RecordMut<'a> = &'a mut Rv32RdWriteAdapterRecord;
300
301    #[inline(always)]
302    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
303        record.from_pc = pc;
304        record.from_timestamp = memory.timestamp;
305    }
306
307    #[inline(always)]
308    fn read(
309        &self,
310        memory: &mut TracingMemory,
311        instruction: &Instruction<F>,
312        record: &mut Self::RecordMut<'_>,
313    ) -> Self::ReadData {
314        <Rv32RdWriteAdapterExecutor as AdapterTraceExecutor<F>>::read(
315            &self.inner,
316            memory,
317            instruction,
318            record,
319        )
320    }
321
322    #[inline(always)]
323    fn write(
324        &self,
325        memory: &mut TracingMemory,
326        instruction: &Instruction<F>,
327        data: Self::WriteData,
328        record: &mut Self::RecordMut<'_>,
329    ) {
330        let Instruction { f: enabled, .. } = instruction;
331
332        if enabled.is_one() {
333            <Rv32RdWriteAdapterExecutor as AdapterTraceExecutor<F>>::write(
334                &self.inner,
335                memory,
336                instruction,
337                data,
338                record,
339            );
340        } else {
341            memory.increment_timestamp();
342            record.rd_ptr = u32::MAX;
343        }
344    }
345}
346
impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32CondRdWriteAdapterFiller {
    const WIDTH: usize = size_of::<Rv32CondRdWriteAdapterCols<u8>>();

    /// Fills [`Rv32CondRdWriteAdapterCols`] from the [`Rv32RdWriteAdapterRecord`] stored at the
    /// start of `adapter_row`.
    ///
    /// `needs_write` is derived from the `rd_ptr == u32::MAX` sentinel left by the executor.
    /// - When a write happened, the leading `Rv32RdWriteAdapterCols` part of the row is filled by
    ///   the inner filler.
    /// - Otherwise `rd_ptr` and the aux columns are zeroed and only `from_state` is filled from
    ///   the record.
    ///
    /// NOTE(review): as in the inner filler, `record` aliases `adapter_row`. `needs_write` is the
    /// last column in the layout, presumably past the record's bytes, which is why it can be
    /// written first. Keep the remaining statement order when editing.
    ///
    /// NOTE(review): in the no-write case the `rd_ptr` column is set to 0, but the AIR uses that
    /// column as operand `a` of the executed instruction. This presumably relies on instructions
    /// with `f = 0` also having `a = 0`; confirm against the transpiler.
    #[inline(always)]
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY:
        // - caller ensures `adapter_row` contains a valid record representation that was previously
        //   written by the executor
        // - get_record_from_slice correctly interprets the bytes as Rv32RdWriteAdapterRecord
        let record: &Rv32RdWriteAdapterRecord =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let adapter_cols: &mut Rv32CondRdWriteAdapterCols<F> = adapter_row.borrow_mut();

        // `u32::MAX` is the executor's "write skipped" sentinel.
        adapter_cols.needs_write = F::from_bool(record.rd_ptr != u32::MAX);

        if record.rd_ptr != u32::MAX {
            // SAFETY:
            // - adapter_row has sufficient length for the split
            // - size_of::<Rv32RdWriteAdapterCols<u8>>() is the correct split point
            unsafe {
                self.inner.fill_trace_row(
                    mem_helper,
                    adapter_row
                        .split_at_mut_unchecked(size_of::<Rv32RdWriteAdapterCols<u8>>())
                        .0,
                )
            };
        } else {
            // No write: no memory aux data was recorded, so zero those columns and fill only
            // the execution state.
            adapter_cols.inner.rd_ptr = F::ZERO;
            mem_helper.fill_zero(adapter_cols.inner.rd_aux_cols.as_mut());
            adapter_cols.inner.from_state.timestamp = F::from_u32(record.from_timestamp);
            adapter_cols.inner.from_state.pc = F::from_u32(record.from_pc);
        }
    }
}