// openvm_rv32im_circuit/adapters/branch.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    marker::PhantomData,
4};
5
6use openvm_circuit::{
7    arch::{
8        AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
9        ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip,
10        VmAdapterInterface,
11    },
12    system::{
13        memory::{
14            offline_checker::{MemoryBridge, MemoryReadAuxCols},
15            MemoryAddress, MemoryController, OfflineMemory, RecordId,
16        },
17        program::ProgramBus,
18    },
19};
20use openvm_circuit_primitives_derive::AlignedBorrow;
21use openvm_instructions::{
22    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
23};
24use openvm_stark_backend::{
25    interaction::InteractionBuilder,
26    p3_air::BaseAir,
27    p3_field::{Field, FieldAlgebra, PrimeField32},
28};
29use serde::{Deserialize, Serialize};
30
31use super::RV32_REGISTER_NUM_LIMBS;
32
33/// Reads instructions of the form OP a, b, c, d, e where if(\[a:4\]_d op \[b:4\]_e) pc += c.
34/// Operands d and e can only be 1.
35#[derive(Debug)]
36pub struct Rv32BranchAdapterChip<F: Field> {
37    pub air: Rv32BranchAdapterAir,
38    _marker: PhantomData<F>,
39}
40
41impl<F: PrimeField32> Rv32BranchAdapterChip<F> {
42    pub fn new(
43        execution_bus: ExecutionBus,
44        program_bus: ProgramBus,
45        memory_bridge: MemoryBridge,
46    ) -> Self {
47        Self {
48            air: Rv32BranchAdapterAir {
49                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
50                memory_bridge,
51            },
52            _marker: PhantomData,
53        }
54    }
55}
56
57#[repr(C)]
58#[derive(Debug, Serialize, Deserialize)]
59pub struct Rv32BranchReadRecord {
60    /// Read register value from address space d = 1
61    pub rs1: RecordId,
62    /// Read register value from address space e = 1
63    pub rs2: RecordId,
64}
65
66#[repr(C)]
67#[derive(Debug, Serialize, Deserialize)]
68pub struct Rv32BranchWriteRecord {
69    pub from_state: ExecutionState<u32>,
70}
71
72#[repr(C)]
73#[derive(AlignedBorrow)]
74pub struct Rv32BranchAdapterCols<T> {
75    pub from_state: ExecutionState<T>,
76    pub rs1_ptr: T,
77    pub rs2_ptr: T,
78    pub reads_aux: [MemoryReadAuxCols<T>; 2],
79}
80
81#[derive(Clone, Copy, Debug, derive_new::new)]
82pub struct Rv32BranchAdapterAir {
83    pub(super) execution_bridge: ExecutionBridge,
84    pub(super) memory_bridge: MemoryBridge,
85}
86
87impl<F: Field> BaseAir<F> for Rv32BranchAdapterAir {
88    fn width(&self) -> usize {
89        Rv32BranchAdapterCols::<F>::width()
90    }
91}
92
93impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32BranchAdapterAir {
94    type Interface =
95        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>;
96
97    fn eval(
98        &self,
99        builder: &mut AB,
100        local: &[AB::Var],
101        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
102    ) {
103        let local: &Rv32BranchAdapterCols<_> = local.borrow();
104        let timestamp = local.from_state.timestamp;
105        let mut timestamp_delta: usize = 0;
106        let mut timestamp_pp = || {
107            timestamp_delta += 1;
108            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
109        };
110
111        self.memory_bridge
112            .read(
113                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rs1_ptr),
114                ctx.reads[0].clone(),
115                timestamp_pp(),
116                &local.reads_aux[0],
117            )
118            .eval(builder, ctx.instruction.is_valid.clone());
119
120        self.memory_bridge
121            .read(
122                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rs2_ptr),
123                ctx.reads[1].clone(),
124                timestamp_pp(),
125                &local.reads_aux[1],
126            )
127            .eval(builder, ctx.instruction.is_valid.clone());
128
129        self.execution_bridge
130            .execute_and_increment_or_set_pc(
131                ctx.instruction.opcode,
132                [
133                    local.rs1_ptr.into(),
134                    local.rs2_ptr.into(),
135                    ctx.instruction.immediate,
136                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
137                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
138                ],
139                local.from_state,
140                AB::F::from_canonical_usize(timestamp_delta),
141                (DEFAULT_PC_STEP, ctx.to_pc),
142            )
143            .eval(builder, ctx.instruction.is_valid);
144    }
145
146    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
147        let cols: &Rv32BranchAdapterCols<_> = local.borrow();
148        cols.from_state.pc
149    }
150}
151
152impl<F: PrimeField32> VmAdapterChip<F> for Rv32BranchAdapterChip<F> {
153    type ReadRecord = Rv32BranchReadRecord;
154    type WriteRecord = Rv32BranchWriteRecord;
155    type Air = Rv32BranchAdapterAir;
156    type Interface = BasicAdapterInterface<F, ImmInstruction<F>, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>;
157
158    fn preprocess(
159        &mut self,
160        memory: &mut MemoryController<F>,
161        instruction: &Instruction<F>,
162    ) -> Result<(
163        <Self::Interface as VmAdapterInterface<F>>::Reads,
164        Self::ReadRecord,
165    )> {
166        let Instruction { a, b, d, e, .. } = *instruction;
167
168        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
169        debug_assert_eq!(e.as_canonical_u32(), RV32_REGISTER_AS);
170
171        let rs1 = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, a);
172        let rs2 = memory.read::<RV32_REGISTER_NUM_LIMBS>(e, b);
173
174        Ok((
175            [rs1.1, rs2.1],
176            Self::ReadRecord {
177                rs1: rs1.0,
178                rs2: rs2.0,
179            },
180        ))
181    }
182
183    fn postprocess(
184        &mut self,
185        memory: &mut MemoryController<F>,
186        _instruction: &Instruction<F>,
187        from_state: ExecutionState<u32>,
188        output: AdapterRuntimeContext<F, Self::Interface>,
189        _read_record: &Self::ReadRecord,
190    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
191        let timestamp_delta = memory.timestamp() - from_state.timestamp;
192        debug_assert!(
193            timestamp_delta == 2,
194            "timestamp delta is {}, expected 2",
195            timestamp_delta
196        );
197
198        Ok((
199            ExecutionState {
200                pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP),
201                timestamp: memory.timestamp(),
202            },
203            Self::WriteRecord { from_state },
204        ))
205    }
206
207    fn generate_trace_row(
208        &self,
209        row_slice: &mut [F],
210        read_record: Self::ReadRecord,
211        write_record: Self::WriteRecord,
212        memory: &OfflineMemory<F>,
213    ) {
214        let aux_cols_factory = memory.aux_cols_factory();
215        let row_slice: &mut Rv32BranchAdapterCols<_> = row_slice.borrow_mut();
216        row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);
217        let rs1 = memory.record_by_id(read_record.rs1);
218        let rs2 = memory.record_by_id(read_record.rs2);
219        row_slice.rs1_ptr = rs1.pointer;
220        row_slice.rs2_ptr = rs2.pointer;
221        aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]);
222        aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]);
223    }
224
225    fn air(&self) -> &Self::Air {
226        &self.air
227    }
228}