openvm_rv32im_circuit/adapters/jalr.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4    arch::{
5        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
6        BasicAdapterInterface, ExecutionBridge, ExecutionState, SignedImmInstruction, VmAdapterAir,
7    },
8    system::memory::{
9        offline_checker::{
10            MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
11            MemoryWriteBytesAuxRecord,
12        },
13        online::TracingMemory,
14        MemoryAddress, MemoryAuxColsFactory,
15    },
16};
17use openvm_circuit_primitives::{utils::not, AlignedBytesBorrow};
18use openvm_circuit_primitives_derive::AlignedBorrow;
19use openvm_instructions::{
20    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
21};
22use openvm_stark_backend::{
23    interaction::InteractionBuilder,
24    p3_air::{AirBuilder, BaseAir},
25    p3_field::{Field, FieldAlgebra, PrimeField32},
26};
27
28use super::RV32_REGISTER_NUM_LIMBS;
29use crate::adapters::{tracing_read, tracing_write};
30
/// Trace columns of the JALR adapter: one `rs1` register read followed by an optional `rd`
/// register write.
///
/// `#[repr(C)]` fixes the field order, which is the column order of the adapter's trace row.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32JalrAdapterCols<T> {
    /// PC and timestamp at which the instruction starts executing.
    pub from_state: ExecutionState<T>,
    /// Register pointer of `rs1` (instruction operand `b`).
    pub rs1_ptr: T,
    /// Auxiliary columns for the `rs1` read, performed at `from_state.timestamp`.
    pub rs1_aux_cols: MemoryReadAuxCols<T>,
    /// Register pointer of `rd` (instruction operand `a`).
    pub rd_ptr: T,
    /// Auxiliary columns for the `rd` write, performed at `from_state.timestamp + 1`.
    pub rd_aux_cols: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
    /// Only writes if `needs_write`.
    /// Sets `needs_write` to 0 iff `rd == x0`.
    /// Constrained to be boolean, forced to 0 on rows that are not valid, and sent as
    /// instruction operand `f`.
    pub needs_write: T,
}
43
/// AIR for the JALR adapter. It constrains the `rs1` register read, the conditional `rd`
/// register write, and the instruction's execution transition.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32JalrAdapterAir {
    /// Bridge used for the `rs1` read and the `rd` write interactions.
    pub(super) memory_bridge: MemoryBridge,
    /// Bridge used to constrain the transition from `from_state` to the post-instruction state.
    pub(super) execution_bridge: ExecutionBridge,
}
49
50impl<F: Field> BaseAir<F> for Rv32JalrAdapterAir {
51    fn width(&self) -> usize {
52        Rv32JalrAdapterCols::<F>::width()
53    }
54}
55
/// Adapter constraints for JALR.
///
/// Interface: one register read (`ctx.reads[0]`, the `rs1` limbs) and one register write
/// (`ctx.writes[0]`, the `rd` limbs), each `RV32_REGISTER_NUM_LIMBS` limbs wide.
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32JalrAdapterAir {
    type Interface = BasicAdapterInterface<
        AB::Expr,
        SignedImmInstruction<AB::Expr>,
        1,
        1,
        RV32_REGISTER_NUM_LIMBS,
        RV32_REGISTER_NUM_LIMBS,
    >;

    /// Constrains one adapter row.
    ///
    /// Timestamps: `rs1` is read at `timestamp` and `rd` is written at `timestamp + 1`. The
    /// instruction always ends at `timestamp + 2`, whether or not `rd` is written.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local_cols: &Rv32JalrAdapterCols<AB::Var> = local.borrow();

        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Each call returns `timestamp + timestamp_delta` and then bumps `timestamp_delta`.
        // Memory accesses therefore receive consecutive timestamps in call order.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
        };

        let write_count = local_cols.needs_write;

        // `needs_write` is boolean and must be 0 on rows where `is_valid` is 0.
        builder.assert_bool(write_count);
        builder
            .when::<AB::Expr>(not(ctx.instruction.is_valid.clone()))
            .assert_zero(write_count);

        // Read `rs1` from the register address space at `timestamp` (enabled by `is_valid`).
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rs1_ptr,
                ),
                ctx.reads[0].clone(),
                timestamp_pp(),
                &local_cols.rs1_aux_cols,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Write `rd` at `timestamp + 1`, enabled only by `needs_write`. The timestamp slot is
        // consumed even when the write is disabled.
        self.memory_bridge
            .write(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rd_ptr,
                ),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &local_cols.rd_aux_cols,
            )
            .eval(builder, write_count);

        // Fall back to `pc + DEFAULT_PC_STEP` when `ctx.to_pc` is `None`.
        let to_pc = ctx
            .to_pc
            .unwrap_or(local_cols.from_state.pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP));

        // regardless of `needs_write`, must always execute instruction when `is_valid`.
        // The operand list mirrors how the executor decodes the instruction:
        // `a` = rd_ptr, `b` = rs1_ptr, `d` = register address space, `f` = needs_write.
        self.execution_bridge
            .execute(
                ctx.instruction.opcode,
                [
                    local_cols.rd_ptr.into(),
                    local_cols.rs1_ptr.into(),
                    ctx.instruction.immediate,
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    AB::Expr::ZERO,
                    write_count.into(),
                    ctx.instruction.imm_sign,
                ],
                local_cols.from_state,
                ExecutionState {
                    pc: to_pc,
                    timestamp: timestamp + AB::F::from_canonical_usize(timestamp_delta),
                },
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the `from_state.pc` column of the row.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32JalrAdapterCols<_> = local.borrow();
        cols.from_state.pc
    }
}
143
/// Execution record written by [`Rv32JalrAdapterExecutor`] and consumed by
/// [`Rv32JalrAdapterFiller`].
///
/// The filler reads this record from the start of the same buffer it then fills as the
/// adapter's trace row. `#[repr(C)]` therefore pins a field layout that the filler's write
/// order depends on.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32JalrAdapterRecord {
    /// PC at which the instruction starts.
    pub from_pc: u32,
    /// Memory timestamp at which the instruction starts (the `rs1` read happens at this time).
    pub from_timestamp: u32,

    /// Register pointer of `rs1` (instruction operand `b`).
    pub rs1_ptr: u32,
    // Will use u32::MAX to indicate no write.
    // u32::MAX is never a canonical value of a 32-bit prime field, so it cannot collide with a
    // real `rd` pointer (instruction operand `a`).
    pub rd_ptr: u32,

    /// Previous timestamp of the `rs1` register, captured during the read.
    pub reads_aux: MemoryReadAuxRecord,
    /// Previous timestamp and previous limbs of `rd`. Only meaningful when a write happened
    /// (`rd_ptr != u32::MAX`).
    pub writes_aux: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
}
157
158// This adapter reads from [b:4]_d (rs1) and writes to [a:4]_d (rd)
159#[derive(Clone, Copy, derive_new::new)]
160pub struct Rv32JalrAdapterExecutor;
161
162#[derive(Clone, Copy, derive_new::new)]
163pub struct Rv32JalrAdapterFiller;
164
165impl<F> AdapterTraceExecutor<F> for Rv32JalrAdapterExecutor
166where
167    F: PrimeField32,
168{
169    const WIDTH: usize = size_of::<Rv32JalrAdapterCols<u8>>();
170    type ReadData = [u8; RV32_REGISTER_NUM_LIMBS];
171    type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
172    type RecordMut<'a> = &'a mut Rv32JalrAdapterRecord;
173
174    #[inline(always)]
175    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
176        record.from_pc = pc;
177        record.from_timestamp = memory.timestamp;
178    }
179
180    #[inline(always)]
181    fn read(
182        &self,
183        memory: &mut TracingMemory,
184        instruction: &Instruction<F>,
185        record: &mut Self::RecordMut<'_>,
186    ) -> Self::ReadData {
187        let &Instruction { b, d, .. } = instruction;
188
189        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
190
191        record.rs1_ptr = b.as_canonical_u32();
192        tracing_read(
193            memory,
194            RV32_REGISTER_AS,
195            b.as_canonical_u32(),
196            &mut record.reads_aux.prev_timestamp,
197        )
198    }
199
200    #[inline(always)]
201    fn write(
202        &self,
203        memory: &mut TracingMemory,
204        instruction: &Instruction<F>,
205        data: Self::WriteData,
206        record: &mut Self::RecordMut<'_>,
207    ) {
208        let &Instruction {
209            a, d, f: enabled, ..
210        } = instruction;
211
212        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
213
214        if enabled.is_one() {
215            record.rd_ptr = a.as_canonical_u32();
216
217            tracing_write(
218                memory,
219                RV32_REGISTER_AS,
220                a.as_canonical_u32(),
221                data,
222                &mut record.writes_aux.prev_timestamp,
223                &mut record.writes_aux.prev_data,
224            );
225        } else {
226            record.rd_ptr = u32::MAX;
227            memory.increment_timestamp();
228        }
229    }
230}
231
/// Fills a JALR adapter trace row in place from the [`Rv32JalrAdapterRecord`] stored at the
/// start of the same row buffer.
impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32JalrAdapterFiller {
    const WIDTH: usize = size_of::<Rv32JalrAdapterCols<u8>>();

    /// Converts the record into [`Rv32JalrAdapterCols`].
    ///
    /// Memory-aux timestamps:
    /// - the `rs1` read is at `from_timestamp`;
    /// - the `rd` write is at `from_timestamp + 1`.
    ///
    /// When `rd_ptr == u32::MAX` (no write), `needs_write` and `rd_ptr` are set to 0 and
    /// `rd_aux_cols` is not written.
    #[inline(always)]
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY:
        // - caller ensures `adapter_row` contains a valid record representation that was previously
        //   written by the executor
        // - get_record_from_slice correctly interprets the bytes as Rv32JalrAdapterRecord
        let record: &Rv32JalrAdapterRecord = unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let adapter_row: &mut Rv32JalrAdapterCols<F> = adapter_row.borrow_mut();

        // We must assign in reverse.
        // `record` aliases the beginning of this row. Columns are therefore written from the
        // last field towards the first, and every record field a column overlaps must be read
        // before that column is written.
        adapter_row.needs_write = F::from_bool(record.rd_ptr != u32::MAX);

        if record.rd_ptr != u32::MAX {
            adapter_row
                .rd_aux_cols
                .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8));
            mem_helper.fill(
                record.writes_aux.prev_timestamp,
                record.from_timestamp + 1,
                adapter_row.rd_aux_cols.as_mut(),
            );
            adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr);
        } else {
            // No write: operand `a` is expected to be 0 for such instructions.
            adapter_row.rd_ptr = F::ZERO;
        }

        mem_helper.fill(
            record.reads_aux.prev_timestamp,
            record.from_timestamp,
            adapter_row.rs1_aux_cols.as_mut(),
        );
        adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr);
        adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
        adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc);
    }
}
270}