// openvm_rv32im_circuit/adapters/rdwrite.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4    arch::{
5        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
6        BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir,
7    },
8    system::memory::{
9        offline_checker::{MemoryBridge, MemoryWriteAuxCols, MemoryWriteBytesAuxRecord},
10        online::TracingMemory,
11        MemoryAddress, MemoryAuxColsFactory,
12    },
13};
14use openvm_circuit_primitives::{utils::not, AlignedBytesBorrow};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{
17    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
18};
19use openvm_stark_backend::{
20    interaction::InteractionBuilder,
21    p3_air::{AirBuilder, BaseAir},
22    p3_field::{Field, FieldAlgebra, PrimeField32},
23};
24
25use super::RV32_REGISTER_NUM_LIMBS;
26use crate::adapters::tracing_write;
27
28#[repr(C)]
29#[derive(Debug, Clone, AlignedBorrow)]
30pub struct Rv32RdWriteAdapterCols<T> {
31    pub from_state: ExecutionState<T>,
32    pub rd_ptr: T,
33    pub rd_aux_cols: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
34}
35
36#[repr(C)]
37#[derive(Debug, Clone, AlignedBorrow)]
38pub struct Rv32CondRdWriteAdapterCols<T> {
39    pub inner: Rv32RdWriteAdapterCols<T>,
40    pub needs_write: T,
41}
42
43/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1
44#[derive(Clone, Copy, Debug, derive_new::new)]
45pub struct Rv32RdWriteAdapterAir {
46    pub(super) memory_bridge: MemoryBridge,
47    pub(super) execution_bridge: ExecutionBridge,
48}
49
50/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1
51#[derive(Clone, Copy, Debug, derive_new::new)]
52pub struct Rv32CondRdWriteAdapterAir {
53    inner: Rv32RdWriteAdapterAir,
54}
55
56impl<F: Field> BaseAir<F> for Rv32RdWriteAdapterAir {
57    fn width(&self) -> usize {
58        Rv32RdWriteAdapterCols::<F>::width()
59    }
60}
61
62impl<F: Field> BaseAir<F> for Rv32CondRdWriteAdapterAir {
63    fn width(&self) -> usize {
64        Rv32CondRdWriteAdapterCols::<F>::width()
65    }
66}
67
68impl Rv32RdWriteAdapterAir {
69    /// If `needs_write` is provided:
70    /// - Only writes if `needs_write`.
71    /// - Sets operand `f = needs_write` in the instruction.
72    /// - Does not put any other constraints on `needs_write`
73    ///
74    /// Otherwise:
75    /// - Writes if `ctx.instruction.is_valid`.
76    /// - Sets operand `f` to default value of `0` in the instruction.
77    #[allow(clippy::type_complexity)]
78    fn conditional_eval<AB: InteractionBuilder>(
79        &self,
80        builder: &mut AB,
81        local_cols: &Rv32RdWriteAdapterCols<AB::Var>,
82        ctx: AdapterAirContext<
83            AB::Expr,
84            BasicAdapterInterface<
85                AB::Expr,
86                ImmInstruction<AB::Expr>,
87                0,
88                1,
89                0,
90                RV32_REGISTER_NUM_LIMBS,
91            >,
92        >,
93        needs_write: Option<AB::Expr>,
94    ) {
95        let timestamp: AB::Var = local_cols.from_state.timestamp;
96        let timestamp_delta = 1;
97        let (write_count, f) = if let Some(needs_write) = needs_write {
98            (needs_write.clone(), needs_write)
99        } else {
100            (ctx.instruction.is_valid.clone(), AB::Expr::ZERO)
101        };
102        self.memory_bridge
103            .write(
104                MemoryAddress::new(
105                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
106                    local_cols.rd_ptr,
107                ),
108                ctx.writes[0].clone(),
109                timestamp,
110                &local_cols.rd_aux_cols,
111            )
112            .eval(builder, write_count);
113
114        let to_pc = ctx
115            .to_pc
116            .unwrap_or(local_cols.from_state.pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP));
117        // regardless of `needs_write`, must always execute instruction when `is_valid`.
118        self.execution_bridge
119            .execute(
120                ctx.instruction.opcode,
121                [
122                    local_cols.rd_ptr.into(),
123                    AB::Expr::ZERO,
124                    ctx.instruction.immediate,
125                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
126                    AB::Expr::ZERO,
127                    f,
128                ],
129                local_cols.from_state,
130                ExecutionState {
131                    pc: to_pc,
132                    timestamp: timestamp + AB::F::from_canonical_usize(timestamp_delta),
133                },
134            )
135            .eval(builder, ctx.instruction.is_valid);
136    }
137}
138
139impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32RdWriteAdapterAir {
140    type Interface =
141        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
142
143    fn eval(
144        &self,
145        builder: &mut AB,
146        local: &[AB::Var],
147        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
148    ) {
149        let local_cols: &Rv32RdWriteAdapterCols<AB::Var> = (*local).borrow();
150        self.conditional_eval(builder, local_cols, ctx, None);
151    }
152
153    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
154        let cols: &Rv32RdWriteAdapterCols<_> = local.borrow();
155        cols.from_state.pc
156    }
157}
158
159impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32CondRdWriteAdapterAir {
160    type Interface =
161        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
162
163    fn eval(
164        &self,
165        builder: &mut AB,
166        local: &[AB::Var],
167        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
168    ) {
169        let local_cols: &Rv32CondRdWriteAdapterCols<AB::Var> = (*local).borrow();
170
171        builder.assert_bool(local_cols.needs_write);
172        builder
173            .when::<AB::Expr>(not(ctx.instruction.is_valid.clone()))
174            .assert_zero(local_cols.needs_write);
175
176        self.inner.conditional_eval(
177            builder,
178            &local_cols.inner,
179            ctx,
180            Some(local_cols.needs_write.into()),
181        );
182    }
183
184    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
185        let cols: &Rv32CondRdWriteAdapterCols<_> = local.borrow();
186        cols.inner.from_state.pc
187    }
188}
189
190/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1
191#[repr(C)]
192#[derive(AlignedBytesBorrow, Debug, Clone)]
193pub struct Rv32RdWriteAdapterRecord {
194    pub from_pc: u32,
195    pub from_timestamp: u32,
196
197    // Will use u32::MAX to indicate no write
198    pub rd_ptr: u32,
199    pub rd_aux_record: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
200}
201
202#[derive(Clone, Copy, derive_new::new)]
203pub struct Rv32RdWriteAdapterExecutor;
204
205#[derive(Clone, Copy, derive_new::new)]
206pub struct Rv32RdWriteAdapterFiller;
207
208impl<F> AdapterTraceExecutor<F> for Rv32RdWriteAdapterExecutor
209where
210    F: PrimeField32,
211{
212    const WIDTH: usize = size_of::<Rv32RdWriteAdapterCols<u8>>();
213    type ReadData = ();
214    type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
215    type RecordMut<'a> = &'a mut Rv32RdWriteAdapterRecord;
216
217    #[inline(always)]
218    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
219        record.from_pc = pc;
220        record.from_timestamp = memory.timestamp;
221    }
222
223    #[inline(always)]
224    fn read(
225        &self,
226        _memory: &mut TracingMemory,
227        _instruction: &Instruction<F>,
228        _record: &mut Self::RecordMut<'_>,
229    ) -> Self::ReadData {
230        // Rv32RdWriteAdapter doesn't read anything
231    }
232
233    #[inline(always)]
234    fn write(
235        &self,
236        memory: &mut TracingMemory,
237        instruction: &Instruction<F>,
238        data: Self::WriteData,
239        record: &mut Self::RecordMut<'_>,
240    ) {
241        let &Instruction { a, d, .. } = instruction;
242
243        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
244
245        record.rd_ptr = a.as_canonical_u32();
246        tracing_write(
247            memory,
248            RV32_REGISTER_AS,
249            record.rd_ptr,
250            data,
251            &mut record.rd_aux_record.prev_timestamp,
252            &mut record.rd_aux_record.prev_data,
253        );
254    }
255}
256
257impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32RdWriteAdapterFiller {
258    const WIDTH: usize = size_of::<Rv32RdWriteAdapterCols<u8>>();
259
260    #[inline(always)]
261    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
262        // SAFETY:
263        // - caller ensures `adapter_row` contains a valid record representation that was previously
264        //   written by the executor
265        // - get_record_from_slice correctly interprets the bytes as Rv32RdWriteAdapterRecord
266        let record: &Rv32RdWriteAdapterRecord =
267            unsafe { get_record_from_slice(&mut adapter_row, ()) };
268        let adapter_row: &mut Rv32RdWriteAdapterCols<F> = adapter_row.borrow_mut();
269
270        adapter_row
271            .rd_aux_cols
272            .set_prev_data(record.rd_aux_record.prev_data.map(F::from_canonical_u8));
273        mem_helper.fill(
274            record.rd_aux_record.prev_timestamp,
275            record.from_timestamp,
276            adapter_row.rd_aux_cols.as_mut(),
277        );
278        adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr);
279        adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
280        adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc);
281    }
282}
283
284/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1
285#[derive(Clone, Copy, derive_new::new)]
286pub struct Rv32CondRdWriteAdapterExecutor {
287    inner: Rv32RdWriteAdapterExecutor,
288}
289
290#[derive(Clone, Copy, derive_new::new)]
291pub struct Rv32CondRdWriteAdapterFiller {
292    inner: Rv32RdWriteAdapterFiller,
293}
294
295impl<F> AdapterTraceExecutor<F> for Rv32CondRdWriteAdapterExecutor
296where
297    F: PrimeField32,
298{
299    const WIDTH: usize = size_of::<Rv32CondRdWriteAdapterCols<u8>>();
300    type ReadData = ();
301    type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
302    type RecordMut<'a> = &'a mut Rv32RdWriteAdapterRecord;
303
304    #[inline(always)]
305    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
306        record.from_pc = pc;
307        record.from_timestamp = memory.timestamp;
308    }
309
310    #[inline(always)]
311    fn read(
312        &self,
313        memory: &mut TracingMemory,
314        instruction: &Instruction<F>,
315        record: &mut Self::RecordMut<'_>,
316    ) -> Self::ReadData {
317        <Rv32RdWriteAdapterExecutor as AdapterTraceExecutor<F>>::read(
318            &self.inner,
319            memory,
320            instruction,
321            record,
322        )
323    }
324
325    #[inline(always)]
326    fn write(
327        &self,
328        memory: &mut TracingMemory,
329        instruction: &Instruction<F>,
330        data: Self::WriteData,
331        record: &mut Self::RecordMut<'_>,
332    ) {
333        let Instruction { f: enabled, .. } = instruction;
334
335        if enabled.is_one() {
336            <Rv32RdWriteAdapterExecutor as AdapterTraceExecutor<F>>::write(
337                &self.inner,
338                memory,
339                instruction,
340                data,
341                record,
342            );
343        } else {
344            memory.increment_timestamp();
345            record.rd_ptr = u32::MAX;
346        }
347    }
348}
349
350impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32CondRdWriteAdapterFiller {
351    const WIDTH: usize = size_of::<Rv32CondRdWriteAdapterCols<u8>>();
352
353    #[inline(always)]
354    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
355        // SAFETY:
356        // - caller ensures `adapter_row` contains a valid record representation that was previously
357        //   written by the executor
358        // - get_record_from_slice correctly interprets the bytes as Rv32RdWriteAdapterRecord
359        let record: &Rv32RdWriteAdapterRecord =
360            unsafe { get_record_from_slice(&mut adapter_row, ()) };
361        let adapter_cols: &mut Rv32CondRdWriteAdapterCols<F> = adapter_row.borrow_mut();
362
363        adapter_cols.needs_write = F::from_bool(record.rd_ptr != u32::MAX);
364
365        if record.rd_ptr != u32::MAX {
366            // SAFETY:
367            // - adapter_row has sufficient length for the split
368            // - size_of::<Rv32RdWriteAdapterCols<u8>>() is the correct split point
369            unsafe {
370                self.inner.fill_trace_row(
371                    mem_helper,
372                    adapter_row
373                        .split_at_mut_unchecked(size_of::<Rv32RdWriteAdapterCols<u8>>())
374                        .0,
375                )
376            };
377        } else {
378            adapter_cols.inner.rd_ptr = F::ZERO;
379            mem_helper.fill_zero(adapter_cols.inner.rd_aux_cols.as_mut());
380            adapter_cols.inner.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
381            adapter_cols.inner.from_state.pc = F::from_canonical_u32(record.from_pc);
382        }
383    }
384}