// openvm_rv32im_circuit/adapters/branch.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4    arch::{
5        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
6        BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir,
7    },
8    system::memory::{
9        offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord},
10        online::TracingMemory,
11        MemoryAddress, MemoryAuxColsFactory,
12    },
13};
14use openvm_circuit_primitives::AlignedBytesBorrow;
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{
17    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
18};
19use openvm_stark_backend::{
20    interaction::InteractionBuilder,
21    p3_air::BaseAir,
22    p3_field::{Field, FieldAlgebra, PrimeField32},
23};
24
25use super::RV32_REGISTER_NUM_LIMBS;
26use crate::adapters::tracing_read;
27
/// Main-trace columns of the RV32 branch adapter: the execution state the instruction
/// starts from, the two source-register pointers, and auxiliary columns for the two
/// register reads.
///
/// `#[repr(C)]` fixes the field order. NOTE(review): `Rv32BranchAdapterFiller` fills these
/// columns in place over the bytes the record was decoded from, so the layout appears to be
/// load-bearing — re-check the filler if fields are added or reordered.
28#[repr(C)]
29#[derive(AlignedBorrow)]
30pub struct Rv32BranchAdapterCols<T> {
    /// pc and timestamp at which this instruction begins; the AIR bases both register reads
    /// on `from_state.timestamp`.
31    pub from_state: ExecutionState<T>,
    /// Pointer (in `RV32_REGISTER_AS`) of the first compared register; operand `a`.
32    pub rs1_ptr: T,
    /// Pointer (in `RV32_REGISTER_AS`) of the second compared register; operand `b`.
33    pub rs2_ptr: T,
    /// Memory-read auxiliary columns: index 0 for the `rs1` read, index 1 for the `rs2` read.
34    pub reads_aux: [MemoryReadAuxCols<T>; 2],
35}
36
/// AIR for the RV32 branch adapter.
///
/// It constrains the two register reads and the execution-bus transition of a branch
/// instruction. The branch operands' values, the immediate and the target pc are supplied by
/// the caller through `AdapterAirContext`. `derive_new::new` takes the bridges in declaration
/// order: `(execution_bridge, memory_bridge)`.
37#[derive(Clone, Copy, Debug, derive_new::new)]
38pub struct Rv32BranchAdapterAir {
    /// Bus used to send the instruction's execution-state transition.
39    pub(super) execution_bridge: ExecutionBridge,
    /// Bus used to constrain the two register reads.
40    pub(super) memory_bridge: MemoryBridge,
41}
42
43impl<F: Field> BaseAir<F> for Rv32BranchAdapterAir {
44    fn width(&self) -> usize {
45        Rv32BranchAdapterCols::<F>::width()
46    }
47}
48
/// Constraints for the branch adapter: one read each of `rs1` and `rs2` from the register
/// address space, followed by one execution-bus interaction. Every interaction is gated by
/// `ctx.instruction.is_valid`.
49impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32BranchAdapterAir {
    // Two register reads of `RV32_REGISTER_NUM_LIMBS` limbs each (`ctx.reads[0]`,
    // `ctx.reads[1]`); the zero parameters presumably mean "no writes" — confirm against
    // `BasicAdapterInterface`.
50    type Interface =
51        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>;
52
    /// Constrains the read of `rs1` at `from_state.timestamp` and of `rs2` one timestamp later.
    /// It then sends the execution-bus interaction for
    /// `opcode rs1_ptr, rs2_ptr, imm, RV32_REGISTER_AS, RV32_REGISTER_AS`, with the timestamp
    /// advanced by the number of reads performed (2).
    ///
    /// Per the bridge method's name, the next pc is either `from_pc + DEFAULT_PC_STEP` or
    /// `ctx.to_pc`. Presumably which one applies depends on whether the caller supplied
    /// `to_pc` — TODO confirm.
53    fn eval(
54        &self,
55        builder: &mut AB,
56        local: &[AB::Var],
57        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
58    ) {
59        let local: &Rv32BranchAdapterCols<_> = local.borrow();
60        let timestamp = local.from_state.timestamp;
61        let mut timestamp_delta: usize = 0;
        // The k-th call (0-based) yields `timestamp + k` and bumps `timestamp_delta`, so after
        // the reads `timestamp_delta` equals the number of memory accesses made.
62        let mut timestamp_pp = || {
63            timestamp_delta += 1;
64            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
65        };
66
        // Read rs1 at `timestamp`. Keep these two reads in this order: the trace filler assigns
        // `reads_aux[0]` the starting timestamp and `reads_aux[1]` the starting timestamp + 1.
67        self.memory_bridge
68            .read(
69                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rs1_ptr),
70                ctx.reads[0].clone(),
71                timestamp_pp(),
72                &local.reads_aux[0],
73            )
74            .eval(builder, ctx.instruction.is_valid.clone());
75
        // Read rs2 at `timestamp + 1`.
76        self.memory_bridge
77            .read(
78                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rs2_ptr),
79                ctx.reads[1].clone(),
80                timestamp_pp(),
81                &local.reads_aux[1],
82            )
83            .eval(builder, ctx.instruction.is_valid.clone());
84
        // Operands are [a, b, c, d, e] = [rs1_ptr, rs2_ptr, immediate, REG_AS, REG_AS].
85        self.execution_bridge
86            .execute_and_increment_or_set_pc(
87                ctx.instruction.opcode,
88                [
89                    local.rs1_ptr.into(),
90                    local.rs2_ptr.into(),
91                    ctx.instruction.immediate,
92                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
93                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
94                ],
95                local.from_state,
96                AB::F::from_canonical_usize(timestamp_delta),
97                (DEFAULT_PC_STEP, ctx.to_pc),
98            )
99            .eval(builder, ctx.instruction.is_valid);
100    }
101
    /// Returns the row's starting pc, read from `from_state`.
102    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
103        let cols: &Rv32BranchAdapterCols<_> = local.borrow();
104        cols.from_state.pc
105    }
106}
107
/// Compact per-instruction record. `Rv32BranchAdapterExecutor` writes it during execution,
/// and `Rv32BranchAdapterFiller` later expands it into an `Rv32BranchAdapterCols` row.
///
/// `#[repr(C)]` fixes the layout. The filler decodes this record from the leading bytes of
/// the same row buffer it then overwrites.
108#[repr(C)]
109#[derive(AlignedBytesBorrow, Debug)]
110pub struct Rv32BranchAdapterRecord {
    /// pc of the branch instruction, captured in `start`.
111    pub from_pc: u32,
    /// Memory timestamp before the first register read, captured in `start`.
112    pub from_timestamp: u32,
    /// Operand `a` (register pointer of rs1).
113    pub rs1_ptr: u32,
    /// Operand `b` (register pointer of rs2).
114    pub rs2_ptr: u32,
    /// Per-read data for rs1 (index 0) and rs2 (index 1); the visible code stores and uses
    /// only each read's `prev_timestamp`.
115    pub reads_aux: [MemoryReadAuxRecord; 2],
116}
117
118/// Reads instructions of the form OP a, b, c, d, e where if(\[a:4\]_d op \[b:4\]_e) pc += c.
119/// Operands d and e can only be 1.
///
/// That is, both `d` and `e` must be the register address space `RV32_REGISTER_AS`; `read`
/// checks this with `debug_assert_eq!`. This adapter only performs the two register reads.
/// The comparison `op` and the resulting next pc are handled outside it.
120#[derive(Clone, Copy, derive_new::new)]
121pub struct Rv32BranchAdapterExecutor;
122
/// Trace filler for the branch adapter. It expands each `Rv32BranchAdapterRecord` into a full
/// `Rv32BranchAdapterCols` row (see its `AdapterTraceFiller` impl).
123#[derive(derive_new::new)]
124pub struct Rv32BranchAdapterFiller;
125
126impl<F> AdapterTraceExecutor<F> for Rv32BranchAdapterExecutor
127where
128    F: PrimeField32,
129{
130    const WIDTH: usize = size_of::<Rv32BranchAdapterCols<u8>>();
131    type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2];
132    type WriteData = ();
133    type RecordMut<'a> = &'a mut Rv32BranchAdapterRecord;
134
135    #[inline(always)]
136    fn start(pc: u32, memory: &TracingMemory, record: &mut &mut Rv32BranchAdapterRecord) {
137        record.from_pc = pc;
138        record.from_timestamp = memory.timestamp;
139    }
140
141    #[inline(always)]
142    fn read(
143        &self,
144        memory: &mut TracingMemory,
145        instruction: &Instruction<F>,
146        record: &mut &mut Rv32BranchAdapterRecord,
147    ) -> Self::ReadData {
148        let &Instruction { a, b, d, e, .. } = instruction;
149
150        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
151        debug_assert_eq!(e.as_canonical_u32(), RV32_REGISTER_AS);
152
153        record.rs1_ptr = a.as_canonical_u32();
154        let rs1 = tracing_read(
155            memory,
156            RV32_REGISTER_AS,
157            a.as_canonical_u32(),
158            &mut record.reads_aux[0].prev_timestamp,
159        );
160        record.rs2_ptr = b.as_canonical_u32();
161        let rs2 = tracing_read(
162            memory,
163            RV32_REGISTER_AS,
164            b.as_canonical_u32(),
165            &mut record.reads_aux[1].prev_timestamp,
166        );
167
168        [rs1, rs2]
169    }
170
171    #[inline(always)]
172    fn write(
173        &self,
174        _memory: &mut TracingMemory,
175        _instruction: &Instruction<F>,
176        _data: Self::WriteData,
177        _record: &mut Self::RecordMut<'_>,
178    ) {
179        // This function is intentionally left empty
180    }
181}
182
/// Trace-generation half of the branch adapter: turns a stored `Rv32BranchAdapterRecord` into
/// a full `Rv32BranchAdapterCols` row.
183impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32BranchAdapterFiller {
    /// Number of trace columns, computed as the byte size of the `u8` instantiation of the
    /// column struct.
184    const WIDTH: usize = size_of::<Rv32BranchAdapterCols<u8>>();
185
    /// Rewrites one adapter row in place.
    ///
    /// First it decodes the record that the executor left at the start of `adapter_row`. It
    /// then overwrites the same slice with the corresponding columns, including the
    /// memory-read auxiliary columns produced by `mem_helper`.
186    #[inline(always)]
187    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
188        // SAFETY:
189        // - caller ensures `adapter_row` contains a valid record representation that was previously
190        //   written by the executor
191        // - get_record_from_slice correctly interprets the bytes as Rv32BranchAdapterRecord
192        let record: &Rv32BranchAdapterRecord =
193            unsafe { get_record_from_slice(&mut adapter_row, ()) };
194        let adapter_row: &mut Rv32BranchAdapterCols<F> = adapter_row.borrow_mut();
195
196        // We must assign in reverse
        // The record is decoded from the same buffer the columns are written into. Columns are
        // therefore filled back to front, and every record field is read before the column
        // overlapping it is written. Keep this order if either struct changes.
197        let timestamp = record.from_timestamp;
198
        // rs2 was read one timestamp after rs1, mirroring the order of the two reads in the AIR.
199        mem_helper.fill(
200            record.reads_aux[1].prev_timestamp,
201            timestamp + 1,
202            adapter_row.reads_aux[1].as_mut(),
203        );
204
        // rs1 was read at the starting timestamp.
205        mem_helper.fill(
206            record.reads_aux[0].prev_timestamp,
207            timestamp,
208            adapter_row.reads_aux[0].as_mut(),
209        );
210
        // Leading columns last: each assignment reads its record field before overwriting it.
211        adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc);
212        adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
213        adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr);
214        adapter_row.rs2_ptr = F::from_canonical_u32(record.rs2_ptr);
215    }
216}