// openvm_rv32im_circuit/adapters/jalr.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    marker::PhantomData,
4};
5
6use openvm_circuit::{
7    arch::{
8        AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
9        ExecutionBus, ExecutionState, Result, SignedImmInstruction, VmAdapterAir, VmAdapterChip,
10        VmAdapterInterface,
11    },
12    system::{
13        memory::{
14            offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
15            MemoryAddress, MemoryController, OfflineMemory, RecordId,
16        },
17        program::ProgramBus,
18    },
19};
20use openvm_circuit_primitives::utils::not;
21use openvm_circuit_primitives_derive::AlignedBorrow;
22use openvm_instructions::{
23    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
24};
25use openvm_stark_backend::{
26    interaction::InteractionBuilder,
27    p3_air::{AirBuilder, BaseAir},
28    p3_field::{Field, FieldAlgebra, PrimeField32},
29};
30use serde::{Deserialize, Serialize};
31
32use super::RV32_REGISTER_NUM_LIMBS;
33
/// Adapter chip for RV32 `JALR`.
///
/// This adapter reads from [b:4]_d (rs1) and writes to [a:4]_d (rd).
/// Its `VmAdapterChip` impl performs those register accesses at runtime and fills
/// the adapter columns of the trace; `F` is only tracked through `PhantomData`.
#[derive(Debug)]
pub struct Rv32JalrAdapterChip<F: Field> {
    /// AIR that constrains the adapter columns of each row.
    pub air: Rv32JalrAdapterAir,
    _marker: PhantomData<F>,
}

41impl<F: PrimeField32> Rv32JalrAdapterChip<F> {
42    pub fn new(
43        execution_bus: ExecutionBus,
44        program_bus: ProgramBus,
45        memory_bridge: MemoryBridge,
46    ) -> Self {
47        Self {
48            air: Rv32JalrAdapterAir {
49                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
50                memory_bridge,
51            },
52            _marker: PhantomData,
53        }
54    }
55}
/// Runtime record produced by the adapter's read phase (`preprocess`).
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Rv32JalrReadRecord {
    /// Memory record of the `rs1` register read; looked up again with
    /// `record_by_id` during trace generation.
    pub rs1: RecordId,
}

/// Runtime record produced by the adapter's write phase (`postprocess`).
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Rv32JalrWriteRecord {
    /// Execution state (`pc`, `timestamp`) before the instruction executed.
    pub from_state: ExecutionState<u32>,
    /// Memory record of the `rd` write, or `None` when no write was performed
    /// (instruction operand `f` was zero).
    pub rd_id: Option<RecordId>,
}

/// Adapter columns of one JALR trace row.
///
/// The struct is `#[repr(C)]` and borrowed directly from a row slice via
/// `AlignedBorrow`, so field declaration order is the column order.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32JalrAdapterCols<T> {
    /// `pc` and `timestamp` at the start of the instruction.
    pub from_state: ExecutionState<T>,
    /// Register pointer of `rs1` (instruction operand `b`).
    pub rs1_ptr: T,
    /// Memory-checker auxiliary columns for the `rs1` read.
    pub rs1_aux_cols: MemoryReadAuxCols<T>,
    /// Register pointer of `rd` (instruction operand `a`). Trace generation only
    /// fills it when a write happens, but the AIR sends it to the execution bus
    /// on every valid row.
    pub rd_ptr: T,
    /// Memory-checker auxiliary columns for the `rd` write.
    pub rd_aux_cols: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
    /// Only writes if `needs_write`.
    /// Sets `needs_write` to 0 iff `rd == x0`
    ///
    /// Constrained boolean and forced to 0 on invalid rows. It is bound to
    /// instruction operand `f` through the execution-bus interaction; the
    /// `rd == x0` relationship comes from the instruction encoding and is not
    /// checked by this adapter.
    pub needs_write: T,
}

/// AIR for the RV32 JALR adapter.
///
/// Constrains the `rs1` read, the optional `rd` write and the execution-bus
/// transition of each row. `derive_new` generates
/// `new(memory_bridge, execution_bridge)` following the field order below,
/// so do not reorder the fields.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32JalrAdapterAir {
    /// Bridge for the register read/write memory interactions.
    pub(super) memory_bridge: MemoryBridge,
    /// Bridge for the execution-bus / program-bus interaction of the instruction.
    pub(super) execution_bridge: ExecutionBridge,
}

88impl<F: Field> BaseAir<F> for Rv32JalrAdapterAir {
89    fn width(&self) -> usize {
90        Rv32JalrAdapterCols::<F>::width()
91    }
92}
93
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32JalrAdapterAir {
    // NOTE(review): the generic arguments presumably are (expr type, instruction,
    // #reads, #writes, read width, write width): one `rs1` read and one `rd` write,
    // each `RV32_REGISTER_NUM_LIMBS` limbs wide -- confirm against `BasicAdapterInterface`.
    type Interface = BasicAdapterInterface<
        AB::Expr,
        SignedImmInstruction<AB::Expr>,
        1,
        1,
        RV32_REGISTER_NUM_LIMBS,
        RV32_REGISTER_NUM_LIMBS,
    >;

    /// Constrains the adapter half of a JALR row.
    ///
    /// - reads `rs1` (register address space) at timestamp `t = from_state.timestamp`,
    ///   enabled by `is_valid`;
    /// - writes `ctx.writes[0]` to `rd` at `t + 1`, enabled only by `needs_write`;
    /// - sends the instruction over the execution bridge, moving from `from_state`
    ///   to `(ctx.to_pc or pc + DEFAULT_PC_STEP, t + 2)`.
    ///
    /// The write's timestamp slot is reserved even when `needs_write == 0`, so every
    /// valid row advances the timestamp by 2; the runtime `postprocess` mirrors this
    /// with `increment_timestamp`. The timestamp offsets depend on the order of the
    /// `timestamp_pp()` calls below.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local_cols: &Rv32JalrAdapterCols<AB::Var> = local.borrow();

        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Returns `timestamp + k` on the k-th call (k starting at 0) and records how
        // many timestamps were consumed in `timestamp_delta`.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
        };

        let write_count = local_cols.needs_write;

        // `needs_write` is a boolean and must be 0 on rows that are not valid.
        builder.assert_bool(write_count);
        builder
            .when::<AB::Expr>(not(ctx.instruction.is_valid.clone()))
            .assert_zero(write_count);

        // rs1 read: performed on every valid row (timestamp offset 0).
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rs1_ptr,
                ),
                ctx.reads[0].clone(),
                timestamp_pp(),
                &local_cols.rs1_aux_cols,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // rd write: only enabled by `needs_write`, but `timestamp_pp()` is still
        // called, so the timestamp offset 1 is consumed either way.
        self.memory_bridge
            .write(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rd_ptr,
                ),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &local_cols.rd_aux_cols,
            )
            .eval(builder, write_count);

        // The core chip may supply the next pc; otherwise fall through by one step.
        let to_pc = ctx
            .to_pc
            .unwrap_or(local_cols.from_state.pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP));

        // regardless of `needs_write`, must always execute instruction when `is_valid`.
        // Operands are in instruction order a..g. The runtime treats a = rd, b = rs1,
        // d = address space and f = write flag; binding `f` to `needs_write` lets the
        // program decide whether rd is written.
        self.execution_bridge
            .execute(
                ctx.instruction.opcode,
                [
                    local_cols.rd_ptr.into(),
                    local_cols.rs1_ptr.into(),
                    ctx.instruction.immediate,
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    AB::Expr::ZERO,
                    write_count.into(),
                    ctx.instruction.imm_sign,
                ],
                local_cols.from_state,
                ExecutionState {
                    pc: to_pc,
                    timestamp: timestamp + AB::F::from_canonical_usize(timestamp_delta),
                },
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the `pc` column of the row's `from_state`.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32JalrAdapterCols<_> = local.borrow();
        cols.from_state.pc
    }
}

182impl<F: PrimeField32> VmAdapterChip<F> for Rv32JalrAdapterChip<F> {
183    type ReadRecord = Rv32JalrReadRecord;
184    type WriteRecord = Rv32JalrWriteRecord;
185    type Air = Rv32JalrAdapterAir;
186    type Interface = BasicAdapterInterface<
187        F,
188        SignedImmInstruction<F>,
189        1,
190        1,
191        RV32_REGISTER_NUM_LIMBS,
192        RV32_REGISTER_NUM_LIMBS,
193    >;
194    fn preprocess(
195        &mut self,
196        memory: &mut MemoryController<F>,
197        instruction: &Instruction<F>,
198    ) -> Result<(
199        <Self::Interface as VmAdapterInterface<F>>::Reads,
200        Self::ReadRecord,
201    )> {
202        let Instruction { b, d, .. } = *instruction;
203        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
204
205        let rs1 = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, b);
206
207        Ok(([rs1.1], Rv32JalrReadRecord { rs1: rs1.0 }))
208    }
209
210    fn postprocess(
211        &mut self,
212        memory: &mut MemoryController<F>,
213        instruction: &Instruction<F>,
214        from_state: ExecutionState<u32>,
215        output: AdapterRuntimeContext<F, Self::Interface>,
216        _read_record: &Self::ReadRecord,
217    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
218        let Instruction {
219            a, d, f: enabled, ..
220        } = *instruction;
221        let rd_id = if enabled != F::ZERO {
222            let (record_id, _) = memory.write(d, a, output.writes[0]);
223            Some(record_id)
224        } else {
225            memory.increment_timestamp();
226            None
227        };
228
229        Ok((
230            ExecutionState {
231                pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP),
232                timestamp: memory.timestamp(),
233            },
234            Self::WriteRecord { from_state, rd_id },
235        ))
236    }
237
238    fn generate_trace_row(
239        &self,
240        row_slice: &mut [F],
241        read_record: Self::ReadRecord,
242        write_record: Self::WriteRecord,
243        memory: &OfflineMemory<F>,
244    ) {
245        let aux_cols_factory = memory.aux_cols_factory();
246        let adapter_cols: &mut Rv32JalrAdapterCols<_> = row_slice.borrow_mut();
247        adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32);
248        let rs1 = memory.record_by_id(read_record.rs1);
249        adapter_cols.rs1_ptr = rs1.pointer;
250        aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols);
251        if let Some(id) = write_record.rd_id {
252            let rd = memory.record_by_id(id);
253            adapter_cols.rd_ptr = rd.pointer;
254            adapter_cols.needs_write = F::ONE;
255            aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.rd_aux_cols);
256        }
257    }
258
259    fn air(&self) -> &Self::Air {
260        &self.air
261    }
262}