openvm_native_circuit/adapters/
branch_native_adapter.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{
7    arch::{
8        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
9        BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir,
10    },
11    system::{
12        memory::{
13            offline_checker::{MemoryBridge, MemoryReadAuxRecord, MemoryReadOrImmediateAuxCols},
14            online::TracingMemory,
15            MemoryAddress, MemoryAuxColsFactory,
16        },
17        native_adapter::util::tracing_read_or_imm_native,
18    },
19};
20use openvm_circuit_primitives::AlignedBytesBorrow;
21use openvm_circuit_primitives_derive::AlignedBorrow;
22use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP};
23use openvm_native_compiler::conversion::AS;
24use openvm_stark_backend::{
25    interaction::InteractionBuilder,
26    p3_air::BaseAir,
27    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
28};
29
30#[repr(C)]
31#[derive(AlignedBorrow, Debug)]
32pub struct BranchNativeAdapterReadCols<T> {
33    pub address: MemoryAddress<T, T>,
34    pub read_aux: MemoryReadOrImmediateAuxCols<T>,
35}
36
37#[repr(C)]
38#[derive(AlignedBorrow, Debug)]
39pub struct BranchNativeAdapterCols<T> {
40    pub from_state: ExecutionState<T>,
41    pub reads_aux: [BranchNativeAdapterReadCols<T>; 2],
42}
43
44#[derive(Clone, Copy, Debug, derive_new::new)]
45pub struct BranchNativeAdapterAir {
46    pub(super) execution_bridge: ExecutionBridge,
47    pub(super) memory_bridge: MemoryBridge,
48}
49
50impl<F: Field> BaseAir<F> for BranchNativeAdapterAir {
51    fn width(&self) -> usize {
52        BranchNativeAdapterCols::<F>::width()
53    }
54}
55
56impl<AB: InteractionBuilder> VmAdapterAir<AB> for BranchNativeAdapterAir {
57    type Interface = BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 2, 0, 1, 1>;
58
59    fn eval(
60        &self,
61        builder: &mut AB,
62        local: &[AB::Var],
63        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
64    ) {
65        let cols: &BranchNativeAdapterCols<_> = local.borrow();
66        let timestamp = cols.from_state.timestamp;
67        let mut timestamp_delta = 0usize;
68        let mut timestamp_pp = || {
69            timestamp_delta += 1;
70            timestamp + AB::F::from_usize(timestamp_delta - 1)
71        };
72
73        // check that d and e are in {0, 4}
74        let d = cols.reads_aux[0].address.address_space;
75        let e = cols.reads_aux[1].address.address_space;
76        builder.assert_eq(d * (d - AB::F::from_u32(AS::Native as u32)), AB::F::ZERO);
77        builder.assert_eq(e * (e - AB::F::from_u32(AS::Native as u32)), AB::F::ZERO);
78
79        self.memory_bridge
80            .read_or_immediate(
81                cols.reads_aux[0].address,
82                ctx.reads[0][0].clone(),
83                timestamp_pp(),
84                &cols.reads_aux[0].read_aux,
85            )
86            .eval(builder, ctx.instruction.is_valid.clone());
87
88        self.memory_bridge
89            .read_or_immediate(
90                cols.reads_aux[1].address,
91                ctx.reads[1][0].clone(),
92                timestamp_pp(),
93                &cols.reads_aux[1].read_aux,
94            )
95            .eval(builder, ctx.instruction.is_valid.clone());
96
97        self.execution_bridge
98            .execute_and_increment_or_set_pc(
99                ctx.instruction.opcode,
100                [
101                    cols.reads_aux[0].address.pointer.into(),
102                    cols.reads_aux[1].address.pointer.into(),
103                    ctx.instruction.immediate,
104                    cols.reads_aux[0].address.address_space.into(),
105                    cols.reads_aux[1].address.address_space.into(),
106                ],
107                cols.from_state,
108                AB::F::from_usize(timestamp_delta),
109                (DEFAULT_PC_STEP, ctx.to_pc),
110            )
111            .eval(builder, ctx.instruction.is_valid);
112    }
113
114    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
115        let cols: &BranchNativeAdapterCols<_> = local.borrow();
116        cols.from_state.pc
117    }
118}
119
120#[repr(C)]
121#[derive(AlignedBytesBorrow, Debug)]
122pub struct BranchNativeAdapterRecord<F> {
123    pub from_pc: u32,
124    pub from_timestamp: u32,
125
126    pub ptrs: [F; 2],
127    // Will set prev_timestamp to `u32::MAX` if the read is an immediate
128    pub reads_aux: [MemoryReadAuxRecord; 2],
129}
130
131#[derive(derive_new::new, Clone, Copy)]
132pub struct BranchNativeAdapterExecutor;
133
134#[derive(derive_new::new)]
135pub struct BranchNativeAdapterFiller;
136
137impl<F> AdapterTraceExecutor<F> for BranchNativeAdapterExecutor
138where
139    F: PrimeField32,
140{
141    const WIDTH: usize = size_of::<BranchNativeAdapterCols<u8>>();
142    type ReadData = [F; 2];
143    type WriteData = ();
144    type RecordMut<'a> = &'a mut BranchNativeAdapterRecord<F>;
145
146    #[inline(always)]
147    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
148        record.from_pc = pc;
149        record.from_timestamp = memory.timestamp;
150    }
151
152    #[inline(always)]
153    fn read(
154        &self,
155        memory: &mut TracingMemory,
156        instruction: &Instruction<F>,
157        record: &mut Self::RecordMut<'_>,
158    ) -> Self::ReadData {
159        let &Instruction { a, b, d, e, .. } = instruction;
160
161        record.ptrs[0] = a;
162        let rs1 = tracing_read_or_imm_native(memory, d, a, &mut record.reads_aux[0].prev_timestamp);
163        record.ptrs[1] = b;
164        let rs2 = tracing_read_or_imm_native(memory, e, b, &mut record.reads_aux[1].prev_timestamp);
165        [rs1, rs2]
166    }
167
168    #[inline(always)]
169    fn write(
170        &self,
171        _memory: &mut TracingMemory,
172        _instruction: &Instruction<F>,
173        _data: Self::WriteData,
174        _record: &mut Self::RecordMut<'_>,
175    ) {
176        // This adapter doesn't write anything
177    }
178}
179
/// Trace-filling side of the adapter: expands the record stored in a trace row into the
/// full `BranchNativeAdapterCols` layout of that same row.
180impl<F: PrimeField32> AdapterTraceFiller<F> for BranchNativeAdapterFiller {
    // Byte size of `BranchNativeAdapterCols<u8>`.
    // NOTE(review): presumably equal to the number of adapter columns (one byte per cell) —
    // confirm against the trait's definition of `WIDTH`.
181    const WIDTH: usize = size_of::<BranchNativeAdapterCols<u8>>();
182
    /// Rewrites `adapter_row` in place from the record it currently holds.
    ///
    /// * `mem_helper` - fills the timestamp-related auxiliary columns of each read.
    /// * `adapter_row` - on entry holds a `BranchNativeAdapterRecord<F>`; on return holds
    ///   the corresponding `BranchNativeAdapterCols<F>` values.
183    #[inline(always)]
184    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
185        // SAFETY:
186        // - caller ensures `adapter_row` contains a valid record representation that was previously
187        //   written by the executor
188        let record: &BranchNativeAdapterRecord<F> =
189            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        // Re-interpret the same slice as the column struct that will be written below.
190        let adapter_row: &mut BranchNativeAdapterCols<F> = adapter_row.borrow_mut();
191
        // NOTE(review): `record` and `adapter_row` are both derived from the same slice, so
        // the order of the reads and writes below is significant; do not reorder statements
        // without checking both struct layouts.
192        // Writing in reverse order to avoid overwriting the `record`
193
194        let native_as = F::from_u32(AS::Native as u32);
        // Walk the two reads from last to first; `i` is the read's index and also its
        // timestamp offset from `from_timestamp`.
195        for ((i, read_record), read_cols) in record
196            .reads_aux
197            .iter()
198            .enumerate()
199            .zip(adapter_row.reads_aux.iter_mut())
200            .rev()
201        {
202            // previous timestamp is u32::MAX if the read is an immediate
203            if read_record.prev_timestamp == u32::MAX {
                // Immediate operand: address space 0, `is_immediate` set, and a previous
                // timestamp of 0 passed to the memory helper.
204                read_cols.read_aux.is_zero_aux = F::ZERO;
205                read_cols.read_aux.is_immediate = F::ONE;
206                mem_helper.fill(
207                    0,
208                    record.from_timestamp + i as u32,
209                    read_cols.read_aux.as_mut(),
210                );
211                read_cols.address.pointer = record.ptrs[i];
212                read_cols.address.address_space = F::ZERO;
213            } else {
                // Memory operand: address space is `AS::Native` and `is_zero_aux` holds its
                // inverse.
                // NOTE(review): presumably the inverse is the witness that the address
                // space is non-zero — confirm against `MemoryReadOrImmediateAuxCols`.
214                read_cols.read_aux.is_zero_aux = native_as.inverse();
215                read_cols.read_aux.is_immediate = F::ZERO;
216                mem_helper.fill(
217                    read_record.prev_timestamp,
218                    record.from_timestamp + i as u32,
219                    read_cols.read_aux.as_mut(),
220                );
221                read_cols.address.pointer = record.ptrs[i];
222                read_cols.address.address_space = native_as;
223            }
224        }
225
        // The starting execution state is written last, after every other use of `record`.
226        adapter_row.from_state.timestamp = F::from_u32(record.from_timestamp);
227        adapter_row.from_state.pc = F::from_u32(record.from_pc);
228    }
229}