openvm_native_circuit/adapters/branch_native_adapter.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{
7    arch::{
8        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
9        BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir,
10    },
11    system::{
12        memory::{
13            offline_checker::{MemoryBridge, MemoryReadAuxRecord, MemoryReadOrImmediateAuxCols},
14            online::TracingMemory,
15            MemoryAddress, MemoryAuxColsFactory,
16        },
17        native_adapter::util::tracing_read_or_imm_native,
18    },
19};
20use openvm_circuit_primitives::AlignedBytesBorrow;
21use openvm_circuit_primitives_derive::AlignedBorrow;
22use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP};
23use openvm_native_compiler::conversion::AS;
24use openvm_stark_backend::{
25    interaction::InteractionBuilder,
26    p3_air::BaseAir,
27    p3_field::{Field, FieldAlgebra, PrimeField32},
28};
29
/// Trace columns for a single branch operand.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct BranchNativeAdapterReadCols<T> {
    /// Operand location. `address_space` is constrained by the AIR to be 0 (immediate) or
    /// `AS::Native`; `pointer` holds the raw instruction operand (`a` or `b`).
    pub address: MemoryAddress<T, T>,
    /// Auxiliary columns for this operand's read-or-immediate memory access.
    pub read_aux: MemoryReadOrImmediateAuxCols<T>,
}
36
/// Complete column layout of the native branch adapter.
///
/// `#[repr(C)]` fixes the field order. `BranchNativeAdapterFiller::fill_trace_row` writes these
/// columns over a `BranchNativeAdapterRecord` that lives in the same row buffer, and its write
/// order depends on this layout.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct BranchNativeAdapterCols<T> {
    /// Program counter and timestamp at the start of the instruction.
    pub from_state: ExecutionState<T>,
    /// Per-operand columns: index 0 is operand `a` (address space `d`), index 1 is operand `b`
    /// (address space `e`).
    pub reads_aux: [BranchNativeAdapterReadCols<T>; 2],
}
43
/// AIR for the native branch adapter.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct BranchNativeAdapterAir {
    /// Bridge that records the executed instruction and its pc/timestamp transition.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used for the two read-or-immediate operand accesses.
    pub(super) memory_bridge: MemoryBridge,
}
49
50impl<F: Field> BaseAir<F> for BranchNativeAdapterAir {
51    fn width(&self) -> usize {
52        BranchNativeAdapterCols::<F>::width()
53    }
54}
55
/// Constraints of the native branch adapter.
///
/// For each row:
/// - both operand address spaces are restricted to `{0, AS::Native}`;
/// - operand `a` (address space `d`) and operand `b` (address space `e`) are read-or-immediate
///   accesses at timestamps `from_state.timestamp + 0` and `+ 1` respectively;
/// - the execution bridge records the instruction operands `[a, b, imm, d, e]`, advancing the
///   timestamp by the number of accesses made above.
///
/// NOTE(review): the order of constraints and bus interactions below presumably feeds into the
/// AIR's symbolic description (and hence its keys); avoid reordering without checking.
impl<AB: InteractionBuilder> VmAdapterAir<AB> for BranchNativeAdapterAir {
    // Presumably (NUM_READS, NUM_WRITES, READ_SIZE, WRITE_SIZE) = (2, 0, 1, 1): two
    // single-element operand reads and no writes.
    type Interface = BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 2, 0, 1, 1>;

    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &BranchNativeAdapterCols<_> = local.borrow();
        let timestamp = cols.from_state.timestamp;
        // Each call yields `timestamp + k` for the k-th (0-based) memory access; afterwards
        // `timestamp_delta` equals the total number of accesses.
        let mut timestamp_delta = 0usize;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        // check that d and e are in {0, AS::Native}: x * (x - Native) == 0
        let d = cols.reads_aux[0].address.address_space;
        let e = cols.reads_aux[1].address.address_space;
        builder.assert_eq(
            d * (d - AB::F::from_canonical_u32(AS::Native as u32)),
            AB::F::ZERO,
        );
        builder.assert_eq(
            e * (e - AB::F::from_canonical_u32(AS::Native as u32)),
            AB::F::ZERO,
        );

        // Operand `a`: the value read (or the immediate) must equal `ctx.reads[0][0]`.
        self.memory_bridge
            .read_or_immediate(
                cols.reads_aux[0].address,
                ctx.reads[0][0].clone(),
                timestamp_pp(),
                &cols.reads_aux[0].read_aux,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Operand `b`: the value read (or the immediate) must equal `ctx.reads[1][0]`.
        self.memory_bridge
            .read_or_immediate(
                cols.reads_aux[1].address,
                ctx.reads[1][0].clone(),
                timestamp_pp(),
                &cols.reads_aux[1].read_aux,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Operand slots are [a, b, c, d, e]; the third slot comes from the instruction's
        // immediate. The next pc is either `ctx.to_pc` or `pc + DEFAULT_PC_STEP`
        // (presumably chosen by whether the core supplies a target — confirm in
        // `execute_and_increment_or_set_pc`).
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.reads_aux[0].address.pointer.into(),
                    cols.reads_aux[1].address.pointer.into(),
                    ctx.instruction.immediate,
                    cols.reads_aux[0].address.address_space.into(),
                    cols.reads_aux[1].address.address_space.into(),
                ],
                cols.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the column holding the instruction's starting pc.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &BranchNativeAdapterCols<_> = local.borrow();
        cols.from_state.pc
    }
}
125
/// Per-instruction record written by `BranchNativeAdapterExecutor` and later read back by
/// `BranchNativeAdapterFiller::fill_trace_row` from the same row buffer it then overwrites
/// with `BranchNativeAdapterCols`. Because of that, the `#[repr(C)]` field order is significant.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct BranchNativeAdapterRecord<F> {
    /// Program counter passed to `start`.
    pub from_pc: u32,
    /// Memory timestamp captured in `start`, before the operand reads.
    pub from_timestamp: u32,

    /// Raw instruction operands `a` and `b` (memory pointer or immediate value).
    pub ptrs: [F; 2],
    // Will set prev_timestamp to `u32::MAX` if the read is an immediate
    pub reads_aux: [MemoryReadAuxRecord; 2],
}
136
/// Stateless executor half of the adapter: records pc/timestamp and performs the two operand
/// reads during execution.
#[derive(derive_new::new, Clone, Copy)]
pub struct BranchNativeAdapterExecutor;
139
/// Stateless trace-filler half of the adapter: turns a `BranchNativeAdapterRecord` into
/// `BranchNativeAdapterCols`.
#[derive(derive_new::new)]
pub struct BranchNativeAdapterFiller;
142
143impl<F> AdapterTraceExecutor<F> for BranchNativeAdapterExecutor
144where
145    F: PrimeField32,
146{
147    const WIDTH: usize = size_of::<BranchNativeAdapterCols<u8>>();
148    type ReadData = [F; 2];
149    type WriteData = ();
150    type RecordMut<'a> = &'a mut BranchNativeAdapterRecord<F>;
151
152    #[inline(always)]
153    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
154        record.from_pc = pc;
155        record.from_timestamp = memory.timestamp;
156    }
157
158    #[inline(always)]
159    fn read(
160        &self,
161        memory: &mut TracingMemory,
162        instruction: &Instruction<F>,
163        record: &mut Self::RecordMut<'_>,
164    ) -> Self::ReadData {
165        let &Instruction { a, b, d, e, .. } = instruction;
166
167        record.ptrs[0] = a;
168        let rs1 = tracing_read_or_imm_native(memory, d, a, &mut record.reads_aux[0].prev_timestamp);
169        record.ptrs[1] = b;
170        let rs2 = tracing_read_or_imm_native(memory, e, b, &mut record.reads_aux[1].prev_timestamp);
171        [rs1, rs2]
172    }
173
174    #[inline(always)]
175    fn write(
176        &self,
177        _memory: &mut TracingMemory,
178        _instruction: &Instruction<F>,
179        _data: Self::WriteData,
180        _record: &mut Self::RecordMut<'_>,
181    ) {
182        // This adapter doesn't write anything
183    }
184}
185
/// Rewrites the record the executor left in a row buffer into `BranchNativeAdapterCols`,
/// in place.
impl<F: PrimeField32> AdapterTraceFiller<F> for BranchNativeAdapterFiller {
    /// Same expression as `BranchNativeAdapterExecutor::WIDTH`.
    const WIDTH: usize = size_of::<BranchNativeAdapterCols<u8>>();

    /// Fills one adapter row from the `BranchNativeAdapterRecord` stored in `adapter_row`.
    ///
    /// `record` and the column view `adapter_row` alias the same buffer, so every record field
    /// must be read before a column overlapping it is written. The statement order below,
    /// including the order of writes inside each loop iteration, is load-bearing.
    /// NOTE(review): the exact overlap depends on `size_of::<F>()` and the aux-column layout;
    /// re-check it before reordering any write.
    #[inline(always)]
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY:
        // - caller ensures `adapter_row` contains a valid record representation that was previously
        //   written by the executor
        let record: &BranchNativeAdapterRecord<F> =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let adapter_row: &mut BranchNativeAdapterCols<F> = adapter_row.borrow_mut();

        // Writing in reverse order to avoid overwriting the `record`: operand 1's columns are
        // filled first (`.rev()`), then operand 0's, and `from_state` last.

        let native_as = F::from_canonical_u32(AS::Native as u32);
        for ((i, read_record), read_cols) in record
            .reads_aux
            .iter()
            .enumerate()
            .zip(adapter_row.reads_aux.iter_mut())
            .rev()
        {
            // `i` is the operand index; its access timestamp `from_timestamp + i` matches the
            // k-th `timestamp_pp()` call in the AIR.
            // previous timestamp is u32::MAX if the read is an immediate
            if read_record.prev_timestamp == u32::MAX {
                // Immediate: address space 0, `is_immediate` = 1. `is_zero_aux` is set to 0
                // (NOTE(review): presumably unconstrained when the address space is 0 — confirm).
                read_cols.read_aux.is_zero_aux = F::ZERO;
                read_cols.read_aux.is_immediate = F::ONE;
                // Aux columns are filled as if the previous access happened at timestamp 0.
                mem_helper.fill(
                    0,
                    record.from_timestamp + i as u32,
                    read_cols.read_aux.as_mut(),
                );
                read_cols.address.pointer = record.ptrs[i];
                read_cols.address.address_space = F::ZERO;
            } else {
                // Native memory read: `is_zero_aux` holds the inverse of the (nonzero) native
                // address space, `is_immediate` = 0.
                read_cols.read_aux.is_zero_aux = native_as.inverse();
                read_cols.read_aux.is_immediate = F::ZERO;
                mem_helper.fill(
                    read_record.prev_timestamp,
                    record.from_timestamp + i as u32,
                    read_cols.read_aux.as_mut(),
                );
                // Keep `pointer` before `address_space`: `record.ptrs[i]` must be read before
                // any overlapping column is written.
                read_cols.address.pointer = record.ptrs[i];
                read_cols.address.address_space = native_as;
            }
        }

        // Each assignment reads its record field before writing the corresponding column.
        adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
        adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc);
    }
}