openvm_native_circuit/adapters/native_vectorized_adapter.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use openvm_circuit::{
7    arch::{
8        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
9        BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir,
10    },
11    system::{
12        memory::{
13            offline_checker::{
14                MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
15                MemoryWriteAuxRecord,
16            },
17            online::TracingMemory,
18            MemoryAddress, MemoryAuxColsFactory,
19        },
20        native_adapter::util::{tracing_read_native, tracing_write_native},
21    },
22};
23use openvm_circuit_primitives::AlignedBytesBorrow;
24use openvm_circuit_primitives_derive::AlignedBorrow;
25use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP};
26use openvm_native_compiler::conversion::AS;
27use openvm_stark_backend::{
28    interaction::InteractionBuilder,
29    p3_air::BaseAir,
30    p3_field::{Field, FieldAlgebra, PrimeField32},
31};
32
/// Trace columns for the native vectorized adapter.
///
/// One row describes an instruction that reads two `N`-element vectors from the native address
/// space (at `b_pointer` and `c_pointer`) and writes one `N`-element vector there (at
/// `a_pointer`).
///
/// `#[repr(C)]` fixes the field order. The trace filler in this module overwrites an executor
/// record with these columns in place, writing fields in reverse order, so the layout must not
/// be reordered casually.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct NativeVectorizedAdapterCols<T, const N: usize> {
    /// Program counter and timestamp of the execution state the instruction starts from.
    pub from_state: ExecutionState<T>,
    /// Destination address (instruction operand `a`) of the written vector.
    pub a_pointer: T,
    /// Address (instruction operand `b`) of the first vector read.
    pub b_pointer: T,
    /// Address (instruction operand `c`) of the second vector read.
    pub c_pointer: T,
    /// Auxiliary columns for the read at `b_pointer` (index 0) and at `c_pointer` (index 1).
    pub reads_aux: [MemoryReadAuxCols<T>; 2],
    /// Auxiliary columns for the single `N`-element write at `a_pointer`.
    pub writes_aux: [MemoryWriteAuxCols<T, N>; 1],
}
43
/// AIR for the native vectorized adapter.
///
/// It constrains two reads and one write in the native address space, plus the instruction's
/// execution-state (pc, timestamp) transition.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct NativeVectorizedAdapterAir<const N: usize> {
    /// Bridge used to constrain the execution-state transition of the instruction.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used to constrain the memory read and write interactions.
    pub(super) memory_bridge: MemoryBridge,
}
49
impl<F: Field, const N: usize> BaseAir<F> for NativeVectorizedAdapterAir<N> {
    /// Trace width of the adapter: the width of [`NativeVectorizedAdapterCols`] over `F`.
    fn width(&self) -> usize {
        NativeVectorizedAdapterCols::<F, N>::width()
    }
}
55
impl<AB: InteractionBuilder, const N: usize> VmAdapterAir<AB> for NativeVectorizedAdapterAir<N> {
    // Two reads of `N` elements each and one write of `N` elements.
    type Interface = BasicAdapterInterface<AB::Expr, MinimalInstruction<AB::Expr>, 2, 1, N, N>;

    /// Constrains the adapter's memory accesses and its execution-state transition.
    ///
    /// Accesses use consecutive timestamps starting at `from_state.timestamp`:
    /// - read `b_pointer` at +0;
    /// - read `c_pointer` at +1;
    /// - write `a_pointer` at +2.
    ///
    /// All three use the native address space. Every interaction is gated by
    /// `ctx.instruction.is_valid`.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &NativeVectorizedAdapterCols<_, N> = local.borrow();
        let timestamp = cols.from_state.timestamp;
        // Each call yields `timestamp + k` for the current offset `k`, then increments `k`.
        // The order of the calls below therefore fixes each access's timestamp, and after the
        // last call `timestamp_delta` equals the number of accesses.
        let mut timestamp_delta = 0usize;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        let native_as = AB::Expr::from_canonical_u32(AS::Native as u32);

        // First operand vector, read from `b_pointer`.
        self.memory_bridge
            .read(
                MemoryAddress::new(native_as.clone(), cols.b_pointer),
                ctx.reads[0].clone(),
                timestamp_pp(),
                &cols.reads_aux[0],
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Second operand vector, read from `c_pointer`.
        self.memory_bridge
            .read(
                MemoryAddress::new(native_as.clone(), cols.c_pointer),
                ctx.reads[1].clone(),
                timestamp_pp(),
                &cols.reads_aux[1],
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Result vector, written to `a_pointer`.
        self.memory_bridge
            .write(
                MemoryAddress::new(native_as.clone(), cols.a_pointer),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &cols.writes_aux[0],
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Operands are [a, b, c, d, e] with d = e = native address space. The timestamp advances
        // by the number of accesses above. (DEFAULT_PC_STEP, ctx.to_pc) is handed to
        // `execute_and_increment_or_set_pc`, which, going by its name, either steps the pc or
        // sets it to `ctx.to_pc`.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.a_pointer.into(),
                    cols.b_pointer.into(),
                    cols.c_pointer.into(),
                    native_as.clone(),
                    native_as.clone(),
                ],
                cols.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the `pc` column of the row's starting execution state.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &NativeVectorizedAdapterCols<_, N> = local.borrow();
        cols.from_state.pc
    }
}
124
/// Per-instruction record produced by [`NativeVectorizedAdapterExecutor`] during execution.
///
/// [`NativeVectorizedAdapterFiller`] later reads it back from the trace row buffer and
/// overwrites it with [`NativeVectorizedAdapterCols`]. `#[repr(C)]` keeps the field order
/// fixed for that in-place conversion.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct NativeVectorizedAdapterRecord<F, const N: usize> {
    /// Program counter at the start of the instruction.
    pub from_pc: u32,
    /// Memory timestamp at the start of the instruction.
    pub from_timestamp: u32,
    /// Instruction operand `a`: native address the result vector is written to.
    pub a_ptr: F,
    /// Instruction operand `b`: native address of the first vector read.
    pub b_ptr: F,
    /// Instruction operand `c`: native address of the second vector read.
    pub c_ptr: F,
    /// Previous-access timestamps for the read at `b_ptr` (index 0) and at `c_ptr` (index 1).
    pub reads_aux: [MemoryReadAuxRecord; 2],
    /// Previous timestamp and previous data for the write at `a_ptr`.
    pub write_aux: MemoryWriteAuxRecord<F, N>,
}
136
/// Execution side of the adapter.
///
/// It performs the traced native-memory reads and write and fills a
/// [`NativeVectorizedAdapterRecord`].
#[derive(derive_new::new, Clone, Copy)]
pub struct NativeVectorizedAdapterExecutor<const N: usize>;
139
/// Trace-generation side of the adapter.
///
/// It converts a [`NativeVectorizedAdapterRecord`] into [`NativeVectorizedAdapterCols`].
#[derive(derive_new::new)]
pub struct NativeVectorizedAdapterFiller<const N: usize>;
142
143impl<F: PrimeField32, const N: usize> AdapterTraceExecutor<F>
144    for NativeVectorizedAdapterExecutor<N>
145{
146    const WIDTH: usize = size_of::<NativeVectorizedAdapterCols<u8, N>>();
147    type ReadData = [[F; N]; 2];
148    type WriteData = [F; N];
149    type RecordMut<'a> = &'a mut NativeVectorizedAdapterRecord<F, N>;
150
151    #[inline(always)]
152    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
153        record.from_pc = pc;
154        record.from_timestamp = memory.timestamp();
155    }
156
157    #[inline(always)]
158    fn read(
159        &self,
160        memory: &mut TracingMemory,
161        instruction: &Instruction<F>,
162        record: &mut Self::RecordMut<'_>,
163    ) -> Self::ReadData {
164        let &Instruction { b, c, d, e, .. } = instruction;
165        debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32);
166        debug_assert_eq!(e.as_canonical_u32(), AS::Native as u32);
167
168        record.b_ptr = b;
169        let b_val = tracing_read_native(
170            memory,
171            b.as_canonical_u32(),
172            &mut record.reads_aux[0].prev_timestamp,
173        );
174        record.c_ptr = c;
175        let c_val = tracing_read_native(
176            memory,
177            c.as_canonical_u32(),
178            &mut record.reads_aux[1].prev_timestamp,
179        );
180
181        [b_val, c_val]
182    }
183
184    #[inline(always)]
185    fn write(
186        &self,
187        memory: &mut TracingMemory,
188        instruction: &Instruction<F>,
189        data: Self::WriteData,
190        record: &mut Self::RecordMut<'_>,
191    ) {
192        let &Instruction { a, d, .. } = instruction;
193
194        debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32);
195
196        record.a_ptr = a;
197        tracing_write_native(
198            memory,
199            a.as_canonical_u32(),
200            data,
201            &mut record.write_aux.prev_timestamp,
202            &mut record.write_aux.prev_data,
203        );
204    }
205}
206
impl<F: PrimeField32, const N: usize> AdapterTraceFiller<F> for NativeVectorizedAdapterFiller<N> {
    // Instantiating the columns with `u8` makes each column exactly one byte, so the byte size
    // equals the column count.
    const WIDTH: usize = size_of::<NativeVectorizedAdapterCols<u8, N>>();

    /// Converts the executor's [`NativeVectorizedAdapterRecord`] held in `adapter_row` into
    /// [`NativeVectorizedAdapterCols`], overwriting the record's bytes in place.
    ///
    /// The timestamps passed to `mem_helper.fill` mirror the AIR: the reads sit at
    /// `from_timestamp + 0` and `+ 1`, and the write at `from_timestamp + 2`.
    #[inline(always)]
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY:
        // - caller ensures `adapter_row` contains a valid record representation that was previously
        //   written by the executor
        let record: &NativeVectorizedAdapterRecord<F, N> =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let adapter_row: &mut NativeVectorizedAdapterCols<F, N> = adapter_row.borrow_mut();

        // Writing in reverse order to avoid overwriting the `record`
        // Columns are filled from the end of the row backwards, so each record field is read
        // before the bytes it occupies can be overwritten. Do not reorder these statements.
        adapter_row.writes_aux[0].set_prev_data(record.write_aux.prev_data);
        mem_helper.fill(
            record.write_aux.prev_timestamp,
            record.from_timestamp + 2,
            adapter_row.writes_aux[0].as_mut(),
        );

        // `enumerate` runs before `rev`, so `i` is still the original read index
        // (0 = read of `b`, 1 = read of `c`) and also the read's timestamp offset.
        adapter_row
            .reads_aux
            .iter_mut()
            .enumerate()
            .zip(record.reads_aux.iter())
            .rev()
            .for_each(|((i, read_cols), read_record)| {
                mem_helper.fill(
                    read_record.prev_timestamp,
                    record.from_timestamp + i as u32,
                    read_cols.as_mut(),
                );
            });

        // Remaining scalar columns, still assigned last-to-first.
        adapter_row.c_pointer = record.c_ptr;
        adapter_row.b_pointer = record.b_ptr;
        adapter_row.a_pointer = record.a_ptr;

        adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
        adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc);
    }
}