// openvm_rv32im_circuit/adapters/jalr.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4    arch::{
5        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
6        BasicAdapterInterface, ExecutionBridge, ExecutionState, SignedImmInstruction, VmAdapterAir,
7    },
8    system::memory::{
9        offline_checker::{
10            MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
11            MemoryWriteBytesAuxRecord,
12        },
13        online::TracingMemory,
14        MemoryAddress, MemoryAuxColsFactory,
15    },
16};
17use openvm_circuit_primitives::{utils::not, AlignedBytesBorrow};
18use openvm_circuit_primitives_derive::AlignedBorrow;
19use openvm_instructions::{
20    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
21};
22use openvm_stark_backend::{
23    interaction::InteractionBuilder,
24    p3_air::{AirBuilder, BaseAir},
25    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
26};
27
28use super::RV32_REGISTER_NUM_LIMBS;
29use crate::adapters::{tracing_read, tracing_write};
30
31#[repr(C)]
32#[derive(Debug, Clone, AlignedBorrow)]
33pub struct Rv32JalrAdapterCols<T> {
34    pub from_state: ExecutionState<T>,
35    pub rs1_ptr: T,
36    pub rs1_aux_cols: MemoryReadAuxCols<T>,
37    pub rd_ptr: T,
38    pub rd_aux_cols: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
39    /// Only writes if `needs_write`.
40    /// Sets `needs_write` to 0 iff `rd == x0`
41    pub needs_write: T,
42}
43
44#[derive(Clone, Copy, Debug, derive_new::new)]
45pub struct Rv32JalrAdapterAir {
46    pub(super) memory_bridge: MemoryBridge,
47    pub(super) execution_bridge: ExecutionBridge,
48}
49
50impl<F: Field> BaseAir<F> for Rv32JalrAdapterAir {
51    fn width(&self) -> usize {
52        Rv32JalrAdapterCols::<F>::width()
53    }
54}
55
56impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32JalrAdapterAir {
57    type Interface = BasicAdapterInterface<
58        AB::Expr,
59        SignedImmInstruction<AB::Expr>,
60        1,
61        1,
62        RV32_REGISTER_NUM_LIMBS,
63        RV32_REGISTER_NUM_LIMBS,
64    >;
65
66    fn eval(
67        &self,
68        builder: &mut AB,
69        local: &[AB::Var],
70        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
71    ) {
72        let local_cols: &Rv32JalrAdapterCols<AB::Var> = local.borrow();
73
74        let timestamp: AB::Var = local_cols.from_state.timestamp;
75        let mut timestamp_delta: usize = 0;
76        let mut timestamp_pp = || {
77            timestamp_delta += 1;
78            timestamp + AB::Expr::from_usize(timestamp_delta - 1)
79        };
80
81        let write_count = local_cols.needs_write;
82
83        builder.assert_bool(write_count);
84        builder
85            .when::<AB::Expr>(not(ctx.instruction.is_valid.clone()))
86            .assert_zero(write_count);
87
88        self.memory_bridge
89            .read(
90                MemoryAddress::new(AB::F::from_u32(RV32_REGISTER_AS), local_cols.rs1_ptr),
91                ctx.reads[0].clone(),
92                timestamp_pp(),
93                &local_cols.rs1_aux_cols,
94            )
95            .eval(builder, ctx.instruction.is_valid.clone());
96
97        self.memory_bridge
98            .write(
99                MemoryAddress::new(AB::F::from_u32(RV32_REGISTER_AS), local_cols.rd_ptr),
100                ctx.writes[0].clone(),
101                timestamp_pp(),
102                &local_cols.rd_aux_cols,
103            )
104            .eval(builder, write_count);
105
106        let to_pc = ctx
107            .to_pc
108            .unwrap_or(local_cols.from_state.pc + AB::F::from_u32(DEFAULT_PC_STEP));
109
110        // regardless of `needs_write`, must always execute instruction when `is_valid`.
111        self.execution_bridge
112            .execute(
113                ctx.instruction.opcode,
114                [
115                    local_cols.rd_ptr.into(),
116                    local_cols.rs1_ptr.into(),
117                    ctx.instruction.immediate,
118                    AB::Expr::from_u32(RV32_REGISTER_AS),
119                    AB::Expr::ZERO,
120                    write_count.into(),
121                    ctx.instruction.imm_sign,
122                ],
123                local_cols.from_state,
124                ExecutionState {
125                    pc: to_pc,
126                    timestamp: timestamp + AB::F::from_usize(timestamp_delta),
127                },
128            )
129            .eval(builder, ctx.instruction.is_valid);
130    }
131
132    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
133        let cols: &Rv32JalrAdapterCols<_> = local.borrow();
134        cols.from_state.pc
135    }
136}
137
138#[repr(C)]
139#[derive(AlignedBytesBorrow, Debug)]
140pub struct Rv32JalrAdapterRecord {
141    pub from_pc: u32,
142    pub from_timestamp: u32,
143
144    pub rs1_ptr: u32,
145    // Will use u32::MAX to indicate no write
146    pub rd_ptr: u32,
147
148    pub reads_aux: MemoryReadAuxRecord,
149    pub writes_aux: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
150}
151
152// This adapter reads from [b:4]_d (rs1) and writes to [a:4]_d (rd)
153#[derive(Clone, Copy, derive_new::new)]
154pub struct Rv32JalrAdapterExecutor;
155
156#[derive(Clone, Copy, derive_new::new)]
157pub struct Rv32JalrAdapterFiller;
158
159impl<F> AdapterTraceExecutor<F> for Rv32JalrAdapterExecutor
160where
161    F: PrimeField32,
162{
163    const WIDTH: usize = size_of::<Rv32JalrAdapterCols<u8>>();
164    type ReadData = [u8; RV32_REGISTER_NUM_LIMBS];
165    type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
166    type RecordMut<'a> = &'a mut Rv32JalrAdapterRecord;
167
168    #[inline(always)]
169    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
170        record.from_pc = pc;
171        record.from_timestamp = memory.timestamp;
172    }
173
174    #[inline(always)]
175    fn read(
176        &self,
177        memory: &mut TracingMemory,
178        instruction: &Instruction<F>,
179        record: &mut Self::RecordMut<'_>,
180    ) -> Self::ReadData {
181        let &Instruction { b, d, .. } = instruction;
182
183        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
184
185        record.rs1_ptr = b.as_canonical_u32();
186        tracing_read(
187            memory,
188            RV32_REGISTER_AS,
189            b.as_canonical_u32(),
190            &mut record.reads_aux.prev_timestamp,
191        )
192    }
193
194    #[inline(always)]
195    fn write(
196        &self,
197        memory: &mut TracingMemory,
198        instruction: &Instruction<F>,
199        data: Self::WriteData,
200        record: &mut Self::RecordMut<'_>,
201    ) {
202        let &Instruction {
203            a, d, f: enabled, ..
204        } = instruction;
205
206        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
207
208        if enabled.is_one() {
209            record.rd_ptr = a.as_canonical_u32();
210
211            tracing_write(
212                memory,
213                RV32_REGISTER_AS,
214                a.as_canonical_u32(),
215                data,
216                &mut record.writes_aux.prev_timestamp,
217                &mut record.writes_aux.prev_data,
218            );
219        } else {
220            record.rd_ptr = u32::MAX;
221            memory.increment_timestamp();
222        }
223    }
224}
225
226impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32JalrAdapterFiller {
227    const WIDTH: usize = size_of::<Rv32JalrAdapterCols<u8>>();
228
229    #[inline(always)]
230    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
231        // SAFETY:
232        // - caller ensures `adapter_row` contains a valid record representation that was previously
233        //   written by the executor
234        // - get_record_from_slice correctly interprets the bytes as Rv32JalrAdapterRecord
235        let record: &Rv32JalrAdapterRecord = unsafe { get_record_from_slice(&mut adapter_row, ()) };
236        let adapter_row: &mut Rv32JalrAdapterCols<F> = adapter_row.borrow_mut();
237
238        // We must assign in reverse
239        adapter_row.needs_write = F::from_bool(record.rd_ptr != u32::MAX);
240
241        if record.rd_ptr != u32::MAX {
242            adapter_row
243                .rd_aux_cols
244                .set_prev_data(record.writes_aux.prev_data.map(F::from_u8));
245            mem_helper.fill(
246                record.writes_aux.prev_timestamp,
247                record.from_timestamp + 1,
248                adapter_row.rd_aux_cols.as_mut(),
249            );
250            adapter_row.rd_ptr = F::from_u32(record.rd_ptr);
251        } else {
252            adapter_row.rd_ptr = F::ZERO;
253        }
254
255        mem_helper.fill(
256            record.reads_aux.prev_timestamp,
257            record.from_timestamp,
258            adapter_row.rs1_aux_cols.as_mut(),
259        );
260        adapter_row.rs1_ptr = F::from_u32(record.rs1_ptr);
261        adapter_row.from_state.timestamp = F::from_u32(record.from_timestamp);
262        adapter_row.from_state.pc = F::from_u32(record.from_pc);
263    }
264}