openvm_circuit/system/native_adapter/mod.rs

1pub mod util;
2
3use std::{
4    borrow::{Borrow, BorrowMut},
5    marker::PhantomData,
6};
7
8use openvm_circuit::{
9    arch::{
10        AdapterAirContext, BasicAdapterInterface, ExecutionBridge, ExecutionState,
11        MinimalInstruction, VmAdapterAir,
12    },
13    system::memory::{
14        offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols},
15        MemoryAddress,
16    },
17};
18use openvm_circuit_primitives::AlignedBytesBorrow;
19use openvm_circuit_primitives_derive::AlignedBorrow;
20use openvm_instructions::{
21    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_IMM_AS, NATIVE_AS,
22};
23use openvm_stark_backend::{
24    interaction::InteractionBuilder,
25    p3_air::BaseAir,
26    p3_field::{Field, FieldAlgebra, PrimeField32},
27};
28use util::{tracing_read_or_imm_native, tracing_write_native};
29
30use super::memory::{online::TracingMemory, MemoryAuxColsFactory};
31use crate::{
32    arch::{get_record_from_slice, AdapterTraceExecutor, AdapterTraceFiller},
33    system::memory::offline_checker::{MemoryReadAuxRecord, MemoryWriteAuxRecord},
34};
35
36#[repr(C)]
37#[derive(AlignedBorrow)]
38pub struct NativeAdapterReadCols<T> {
39    pub address: MemoryAddress<T, T>,
40    pub read_aux: MemoryReadOrImmediateAuxCols<T>,
41}
42
43#[repr(C)]
44#[derive(AlignedBorrow)]
45pub struct NativeAdapterWriteCols<T> {
46    pub address: MemoryAddress<T, T>,
47    pub write_aux: MemoryWriteAuxCols<T, 1>,
48}
49
50#[repr(C)]
51#[derive(AlignedBorrow)]
52pub struct NativeAdapterCols<T, const R: usize, const W: usize> {
53    pub from_state: ExecutionState<T>,
54    pub reads_aux: [NativeAdapterReadCols<T>; R],
55    pub writes_aux: [NativeAdapterWriteCols<T>; W],
56}
57
58#[derive(Clone, Copy, Debug, derive_new::new)]
59pub struct NativeAdapterAir<const R: usize, const W: usize> {
60    pub(super) execution_bridge: ExecutionBridge,
61    pub(super) memory_bridge: MemoryBridge,
62}
63
64impl<F: Field, const R: usize, const W: usize> BaseAir<F> for NativeAdapterAir<R, W> {
65    fn width(&self) -> usize {
66        NativeAdapterCols::<F, R, W>::width()
67    }
68}
69
70impl<AB: InteractionBuilder, const R: usize, const W: usize> VmAdapterAir<AB>
71    for NativeAdapterAir<R, W>
72{
73    type Interface = BasicAdapterInterface<AB::Expr, MinimalInstruction<AB::Expr>, R, W, 1, 1>;
74
75    fn eval(
76        &self,
77        builder: &mut AB,
78        local: &[AB::Var],
79        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
80    ) {
81        let cols: &NativeAdapterCols<_, R, W> = local.borrow();
82        let timestamp = cols.from_state.timestamp;
83        let mut timestamp_delta = 0usize;
84        let mut timestamp_pp = || {
85            timestamp_delta += 1;
86            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
87        };
88
89        for (i, r_cols) in cols.reads_aux.iter().enumerate() {
90            self.memory_bridge
91                .read_or_immediate(
92                    r_cols.address,
93                    ctx.reads[i][0].clone(),
94                    timestamp_pp(),
95                    &r_cols.read_aux,
96                )
97                .eval(builder, ctx.instruction.is_valid.clone());
98        }
99        for (i, w_cols) in cols.writes_aux.iter().enumerate() {
100            self.memory_bridge
101                .write(
102                    w_cols.address,
103                    ctx.writes[i].clone(),
104                    timestamp_pp(),
105                    &w_cols.write_aux,
106                )
107                .eval(builder, ctx.instruction.is_valid.clone());
108        }
109
110        let zero_address =
111            || MemoryAddress::new(AB::Expr::from(AB::F::ZERO), AB::Expr::from(AB::F::ZERO));
112        let f = |var_addr: MemoryAddress<AB::Var, AB::Var>| -> MemoryAddress<AB::Expr, AB::Expr> {
113            MemoryAddress::new(var_addr.address_space.into(), var_addr.pointer.into())
114        };
115
116        let addr_a = if W >= 1 {
117            f(cols.writes_aux[0].address)
118        } else {
119            zero_address()
120        };
121        let addr_b = if R >= 1 {
122            f(cols.reads_aux[0].address)
123        } else {
124            zero_address()
125        };
126        let addr_c = if R >= 2 {
127            f(cols.reads_aux[1].address)
128        } else {
129            zero_address()
130        };
131        self.execution_bridge
132            .execute_and_increment_or_set_pc(
133                ctx.instruction.opcode,
134                [
135                    addr_a.pointer,
136                    addr_b.pointer,
137                    addr_c.pointer,
138                    addr_a.address_space,
139                    addr_b.address_space,
140                    addr_c.address_space,
141                ],
142                cols.from_state,
143                AB::F::from_canonical_usize(timestamp_delta),
144                (DEFAULT_PC_STEP, ctx.to_pc),
145            )
146            .eval(builder, ctx.instruction.is_valid);
147    }
148
149    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
150        let cols: &NativeAdapterCols<_, R, W> = local.borrow();
151        cols.from_state.pc
152    }
153}
154
155#[repr(C)]
156#[derive(AlignedBytesBorrow, Debug)]
157pub struct NativeAdapterRecord<F, const R: usize, const W: usize> {
158    pub from_pc: u32,
159    pub from_timestamp: u32,
160
161    // These are either a pointer to native memory or an immediate value
162    pub read_ptr_or_imm: [F; R],
163    // Will set prev_timestamp to `u32::MAX` if the read is from RV32_IMM_AS
164    pub reads_aux: [MemoryReadAuxRecord; R],
165    pub write_ptr: [F; W],
166    pub writes_aux: [MemoryWriteAuxRecord<F, 1>; W],
167}
168
/// Execution-side adapter for native instructions with `R` reads (`R <= 2`) and `W` writes
/// (`W <= 1`).
///
/// Operand mapping: the first read uses operand `b`, the second read uses operand `c`, and the
/// write uses operand `a`. Any operand the adapter does not use must have both its address
/// space and its pointer set to 0.
#[derive(Clone, Debug)]
pub struct NativeAdapterExecutor<F, const R: usize, const W: usize> {
    _phantom: PhantomData<F>,
}

// Written by hand instead of derived so that `Default` does not require `F: Default`.
impl<F, const R: usize, const W: usize> Default for NativeAdapterExecutor<F, R, W> {
    fn default() -> Self {
        let _phantom = PhantomData;
        Self { _phantom }
    }
}
184
185impl<F, const R: usize, const W: usize> AdapterTraceExecutor<F> for NativeAdapterExecutor<F, R, W>
186where
187    F: PrimeField32,
188{
189    const WIDTH: usize = size_of::<NativeAdapterCols<u8, R, W>>();
190    type ReadData = [[F; 1]; R];
191    type WriteData = [[F; 1]; W];
192    type RecordMut<'a> = &'a mut NativeAdapterRecord<F, R, W>;
193
194    #[inline(always)]
195    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
196        record.from_pc = pc;
197        record.from_timestamp = memory.timestamp;
198    }
199
200    #[inline(always)]
201    fn read(
202        &self,
203        memory: &mut TracingMemory,
204        instruction: &Instruction<F>,
205        record: &mut Self::RecordMut<'_>,
206    ) -> Self::ReadData {
207        debug_assert!(R <= 2);
208        let &Instruction { b, c, e, f, .. } = instruction;
209
210        let mut reads = [[F::ZERO; 1]; R];
211        record
212            .read_ptr_or_imm
213            .iter_mut()
214            .enumerate()
215            .zip(record.reads_aux.iter_mut())
216            .for_each(|((i, ptr_or_imm), read_aux)| {
217                *ptr_or_imm = if i == 0 { b } else { c };
218                let addr_space = if i == 0 { e } else { f };
219                reads[i][0] = tracing_read_or_imm_native(
220                    memory,
221                    addr_space,
222                    *ptr_or_imm,
223                    &mut read_aux.prev_timestamp,
224                );
225            });
226        reads
227    }
228
229    #[inline(always)]
230    fn write(
231        &self,
232        memory: &mut TracingMemory,
233        instruction: &Instruction<F>,
234        data: Self::WriteData,
235        record: &mut Self::RecordMut<'_>,
236    ) {
237        let &Instruction { a, d, .. } = instruction;
238        debug_assert!(W <= 1);
239        debug_assert_eq!(d.as_canonical_u32(), NATIVE_AS);
240
241        if W >= 1 {
242            record.write_ptr[0] = a;
243            tracing_write_native(
244                memory,
245                a.as_canonical_u32(),
246                data[0],
247                &mut record.writes_aux[0].prev_timestamp,
248                &mut record.writes_aux[0].prev_data,
249            );
250        }
251    }
252}
253
254impl<F: PrimeField32, const R: usize, const W: usize> AdapterTraceFiller<F>
255    for NativeAdapterExecutor<F, R, W>
256{
257    const WIDTH: usize = size_of::<NativeAdapterCols<u8, R, W>>();
258
259    #[inline(always)]
260    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
261        // SAFETY:
262        // - caller ensures `adapter_row` contains a valid record representation that was previously
263        //   written by the executor
264        let record: &NativeAdapterRecord<F, R, W> =
265            unsafe { get_record_from_slice(&mut adapter_row, ()) };
266        let adapter_row: &mut NativeAdapterCols<_, R, W> = adapter_row.borrow_mut();
267        // Writing in reverse order to avoid overwriting the `record`
268        if W >= 1 {
269            adapter_row.writes_aux[0]
270                .write_aux
271                .set_prev_data(record.writes_aux[0].prev_data);
272            mem_helper.fill(
273                record.writes_aux[0].prev_timestamp,
274                record.from_timestamp + R as u32,
275                adapter_row.writes_aux[0].write_aux.as_mut(),
276            );
277            adapter_row.writes_aux[0].address.pointer = record.write_ptr[0];
278            adapter_row.writes_aux[0].address.address_space = F::from_canonical_u32(NATIVE_AS);
279        }
280
281        adapter_row
282            .reads_aux
283            .iter_mut()
284            .enumerate()
285            .zip(record.reads_aux.iter().zip(record.read_ptr_or_imm.iter()))
286            .rev()
287            .for_each(|((i, read_cols), (read_record, ptr_or_imm))| {
288                if read_record.prev_timestamp == u32::MAX {
289                    read_cols.read_aux.is_zero_aux = F::ZERO;
290                    read_cols.read_aux.is_immediate = F::ONE;
291                    mem_helper.fill(
292                        0,
293                        record.from_timestamp + i as u32,
294                        read_cols.read_aux.as_mut(),
295                    );
296                    read_cols.address.pointer = *ptr_or_imm;
297                    read_cols.address.address_space = F::from_canonical_u32(RV32_IMM_AS);
298                } else {
299                    read_cols.read_aux.is_zero_aux = F::from_canonical_u32(NATIVE_AS).inverse();
300                    read_cols.read_aux.is_immediate = F::ZERO;
301                    mem_helper.fill(
302                        read_record.prev_timestamp,
303                        record.from_timestamp + i as u32,
304                        read_cols.read_aux.as_mut(),
305                    );
306                    read_cols.address.pointer = *ptr_or_imm;
307                    read_cols.address.address_space = F::from_canonical_u32(NATIVE_AS);
308                }
309            });
310
311        adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
312        adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc);
313    }
314}