1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4 arch::{
5 get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
6 BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir,
7 },
8 system::memory::{
9 offline_checker::{MemoryBridge, MemoryWriteAuxCols, MemoryWriteBytesAuxRecord},
10 online::TracingMemory,
11 MemoryAddress, MemoryAuxColsFactory,
12 },
13};
14use openvm_circuit_primitives::{utils::not, AlignedBytesBorrow};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{
17 instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
18};
19use openvm_stark_backend::{
20 interaction::InteractionBuilder,
21 p3_air::{AirBuilder, BaseAir},
22 p3_field::{Field, FieldAlgebra, PrimeField32},
23};
24
25use super::RV32_REGISTER_NUM_LIMBS;
26use crate::adapters::tracing_write;
27
/// Trace columns of [`Rv32RdWriteAdapterAir`].
///
/// Field order is significant: `#[repr(C)]` together with `AlignedBorrow` maps each field to a
/// fixed column offset. The filler reads its execution record out of the same row slice it then
/// fills with these columns (`get_record_from_slice` followed by `borrow_mut` in
/// `fill_trace_row`). Changing the order or sizes of these fields may therefore also require
/// revisiting the order of writes there.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32RdWriteAdapterCols<T> {
    /// Program counter and timestamp at which the instruction starts.
    pub from_state: ExecutionState<T>,
    /// Register pointer that is written; it is also instruction operand `a`.
    pub rd_ptr: T,
    /// Auxiliary columns for the single `RV32_REGISTER_NUM_LIMBS`-limb register write.
    pub rd_aux_cols: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
35
/// Trace columns of [`Rv32CondRdWriteAdapterAir`].
///
/// These are the unconditional adapter's columns followed by a flag that says whether the
/// register write actually happens. `inner` must remain the first field: the filler passes the
/// leading `size_of::<Rv32RdWriteAdapterCols<u8>>()` entries of the row to the inner filler.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32CondRdWriteAdapterCols<T> {
    /// Columns shared with the unconditional adapter.
    pub inner: Rv32RdWriteAdapterCols<T>,
    /// Boolean flag, `1` iff `rd` is written. It is exposed as instruction operand `f` and must
    /// be `0` on rows where the instruction is not valid.
    pub needs_write: T,
}
42
/// Adapter AIR that reads nothing and writes one register.
///
/// The register is at `rd_ptr` in address space `RV32_REGISTER_AS`. The AIR then advances the pc
/// to the core-supplied `to_pc` (or by `DEFAULT_PC_STEP` when none is given) and the timestamp
/// by one.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32RdWriteAdapterAir {
    /// Bus used to constrain the register write.
    pub(super) memory_bridge: MemoryBridge,
    /// Bus used to constrain the execution-state (pc, timestamp) transition.
    pub(super) execution_bridge: ExecutionBridge,
}
49
/// Variant of [`Rv32RdWriteAdapterAir`] that writes `rd` only when its `needs_write` column is
/// set. The flag is exposed as instruction operand `f`.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32CondRdWriteAdapterAir {
    /// Unconditional adapter whose constraints are reused with a conditional write multiplicity.
    inner: Rv32RdWriteAdapterAir,
}
55
56impl<F: Field> BaseAir<F> for Rv32RdWriteAdapterAir {
57 fn width(&self) -> usize {
58 Rv32RdWriteAdapterCols::<F>::width()
59 }
60}
61
62impl<F: Field> BaseAir<F> for Rv32CondRdWriteAdapterAir {
63 fn width(&self) -> usize {
64 Rv32CondRdWriteAdapterCols::<F>::width()
65 }
66}
67
68impl Rv32RdWriteAdapterAir {
69 #[allow(clippy::type_complexity)]
78 fn conditional_eval<AB: InteractionBuilder>(
79 &self,
80 builder: &mut AB,
81 local_cols: &Rv32RdWriteAdapterCols<AB::Var>,
82 ctx: AdapterAirContext<
83 AB::Expr,
84 BasicAdapterInterface<
85 AB::Expr,
86 ImmInstruction<AB::Expr>,
87 0,
88 1,
89 0,
90 RV32_REGISTER_NUM_LIMBS,
91 >,
92 >,
93 needs_write: Option<AB::Expr>,
94 ) {
95 let timestamp: AB::Var = local_cols.from_state.timestamp;
96 let timestamp_delta = 1;
97 let (write_count, f) = if let Some(needs_write) = needs_write {
98 (needs_write.clone(), needs_write)
99 } else {
100 (ctx.instruction.is_valid.clone(), AB::Expr::ZERO)
101 };
102 self.memory_bridge
103 .write(
104 MemoryAddress::new(
105 AB::F::from_canonical_u32(RV32_REGISTER_AS),
106 local_cols.rd_ptr,
107 ),
108 ctx.writes[0].clone(),
109 timestamp,
110 &local_cols.rd_aux_cols,
111 )
112 .eval(builder, write_count);
113
114 let to_pc = ctx
115 .to_pc
116 .unwrap_or(local_cols.from_state.pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP));
117 self.execution_bridge
119 .execute(
120 ctx.instruction.opcode,
121 [
122 local_cols.rd_ptr.into(),
123 AB::Expr::ZERO,
124 ctx.instruction.immediate,
125 AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
126 AB::Expr::ZERO,
127 f,
128 ],
129 local_cols.from_state,
130 ExecutionState {
131 pc: to_pc,
132 timestamp: timestamp + AB::F::from_canonical_usize(timestamp_delta),
133 },
134 )
135 .eval(builder, ctx.instruction.is_valid);
136 }
137}
138
139impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32RdWriteAdapterAir {
140 type Interface =
141 BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
142
143 fn eval(
144 &self,
145 builder: &mut AB,
146 local: &[AB::Var],
147 ctx: AdapterAirContext<AB::Expr, Self::Interface>,
148 ) {
149 let local_cols: &Rv32RdWriteAdapterCols<AB::Var> = (*local).borrow();
150 self.conditional_eval(builder, local_cols, ctx, None);
151 }
152
153 fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
154 let cols: &Rv32RdWriteAdapterCols<_> = local.borrow();
155 cols.from_state.pc
156 }
157}
158
159impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32CondRdWriteAdapterAir {
160 type Interface =
161 BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
162
163 fn eval(
164 &self,
165 builder: &mut AB,
166 local: &[AB::Var],
167 ctx: AdapterAirContext<AB::Expr, Self::Interface>,
168 ) {
169 let local_cols: &Rv32CondRdWriteAdapterCols<AB::Var> = (*local).borrow();
170
171 builder.assert_bool(local_cols.needs_write);
172 builder
173 .when::<AB::Expr>(not(ctx.instruction.is_valid.clone()))
174 .assert_zero(local_cols.needs_write);
175
176 self.inner.conditional_eval(
177 builder,
178 &local_cols.inner,
179 ctx,
180 Some(local_cols.needs_write.into()),
181 );
182 }
183
184 fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
185 let cols: &Rv32CondRdWriteAdapterCols<_> = local.borrow();
186 cols.inner.from_state.pc
187 }
188}
189
/// Execution record shared by both rd-write adapters.
///
/// The executor writes this record during execution, and the filler later expands it into trace
/// columns. `#[repr(C)]` layout matters here: the fillers obtain the record from the start of
/// the adapter's row slice (`get_record_from_slice`) and then overwrite that same slice with
/// columns.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug, Clone)]
pub struct Rv32RdWriteAdapterRecord {
    /// Program counter at instruction start.
    pub from_pc: u32,
    /// Memory timestamp at instruction start.
    pub from_timestamp: u32,

    /// Canonical value of instruction operand `a`, the register pointer written.
    /// The conditional executor stores `u32::MAX` here to mark a skipped write.
    pub rd_ptr: u32,
    /// Previous-timestamp / previous-data slots of the register write, filled by
    /// `tracing_write`.
    pub rd_aux_record: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
}
201
/// Execution-side half of [`Rv32RdWriteAdapterAir`]: performs the traced register write of `rd`
/// and fills a [`Rv32RdWriteAdapterRecord`].
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32RdWriteAdapterExecutor;
204
/// Trace-filling half of [`Rv32RdWriteAdapterAir`]: expands a [`Rv32RdWriteAdapterRecord`] into
/// [`Rv32RdWriteAdapterCols`].
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32RdWriteAdapterFiller;
207
208impl<F> AdapterTraceExecutor<F> for Rv32RdWriteAdapterExecutor
209where
210 F: PrimeField32,
211{
212 const WIDTH: usize = size_of::<Rv32RdWriteAdapterCols<u8>>();
213 type ReadData = ();
214 type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
215 type RecordMut<'a> = &'a mut Rv32RdWriteAdapterRecord;
216
217 #[inline(always)]
218 fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
219 record.from_pc = pc;
220 record.from_timestamp = memory.timestamp;
221 }
222
223 #[inline(always)]
224 fn read(
225 &self,
226 _memory: &mut TracingMemory,
227 _instruction: &Instruction<F>,
228 _record: &mut Self::RecordMut<'_>,
229 ) -> Self::ReadData {
230 }
232
233 #[inline(always)]
234 fn write(
235 &self,
236 memory: &mut TracingMemory,
237 instruction: &Instruction<F>,
238 data: Self::WriteData,
239 record: &mut Self::RecordMut<'_>,
240 ) {
241 let &Instruction { a, d, .. } = instruction;
242
243 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
244
245 record.rd_ptr = a.as_canonical_u32();
246 tracing_write(
247 memory,
248 RV32_REGISTER_AS,
249 record.rd_ptr,
250 data,
251 &mut record.rd_aux_record.prev_timestamp,
252 &mut record.rd_aux_record.prev_data,
253 );
254 }
255}
256
impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32RdWriteAdapterFiller {
    const WIDTH: usize = size_of::<Rv32RdWriteAdapterCols<u8>>();

    /// Expands the [`Rv32RdWriteAdapterRecord`] stored at the start of `adapter_row` into
    /// [`Rv32RdWriteAdapterCols`], in place.
    ///
    /// The record is read out of the same slice that is then written as columns, so the two
    /// views overlap. The statements below appear to be ordered so that each record field is
    /// read before any column covering it is written. Do not reorder them without re-checking
    /// the layouts of both types.
    #[inline(always)]
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY: the caller must ensure that `adapter_row` holds a valid
        // `Rv32RdWriteAdapterRecord` previously written by `Rv32RdWriteAdapterExecutor`.
        // NOTE(review): `record` aliases `adapter_row`; correctness relies on the write order
        // below.
        let record: &Rv32RdWriteAdapterRecord =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let adapter_row: &mut Rv32RdWriteAdapterCols<F> = adapter_row.borrow_mut();

        // Columns are written starting from the last field (`rd_aux_cols`) and ending with the
        // first (`from_state`), presumably to stay clear of record bytes not yet consumed.
        adapter_row
            .rd_aux_cols
            .set_prev_data(record.rd_aux_record.prev_data.map(F::from_canonical_u8));
        // Both timestamps are read as call arguments before the aux columns are written.
        mem_helper.fill(
            record.rd_aux_record.prev_timestamp,
            record.from_timestamp,
            adapter_row.rd_aux_cols.as_mut(),
        );
        adapter_row.rd_ptr = F::from_canonical_u32(record.rd_ptr);
        adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
        adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc);
    }
}
283
/// Execution-side half of [`Rv32CondRdWriteAdapterAir`].
///
/// It wraps the unconditional executor and performs the register write only when instruction
/// operand `f` is one.
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32CondRdWriteAdapterExecutor {
    /// Unconditional executor used when the write is enabled.
    inner: Rv32RdWriteAdapterExecutor,
}
289
/// Trace-filling half of [`Rv32CondRdWriteAdapterAir`].
///
/// It sets the `needs_write` flag and delegates to the unconditional filler when the write
/// happened.
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32CondRdWriteAdapterFiller {
    /// Filler for the leading, unconditional part of the row.
    inner: Rv32RdWriteAdapterFiller,
}
294
295impl<F> AdapterTraceExecutor<F> for Rv32CondRdWriteAdapterExecutor
296where
297 F: PrimeField32,
298{
299 const WIDTH: usize = size_of::<Rv32CondRdWriteAdapterCols<u8>>();
300 type ReadData = ();
301 type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
302 type RecordMut<'a> = &'a mut Rv32RdWriteAdapterRecord;
303
304 #[inline(always)]
305 fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
306 record.from_pc = pc;
307 record.from_timestamp = memory.timestamp;
308 }
309
310 #[inline(always)]
311 fn read(
312 &self,
313 memory: &mut TracingMemory,
314 instruction: &Instruction<F>,
315 record: &mut Self::RecordMut<'_>,
316 ) -> Self::ReadData {
317 <Rv32RdWriteAdapterExecutor as AdapterTraceExecutor<F>>::read(
318 &self.inner,
319 memory,
320 instruction,
321 record,
322 )
323 }
324
325 #[inline(always)]
326 fn write(
327 &self,
328 memory: &mut TracingMemory,
329 instruction: &Instruction<F>,
330 data: Self::WriteData,
331 record: &mut Self::RecordMut<'_>,
332 ) {
333 let Instruction { f: enabled, .. } = instruction;
334
335 if enabled.is_one() {
336 <Rv32RdWriteAdapterExecutor as AdapterTraceExecutor<F>>::write(
337 &self.inner,
338 memory,
339 instruction,
340 data,
341 record,
342 );
343 } else {
344 memory.increment_timestamp();
345 record.rd_ptr = u32::MAX;
346 }
347 }
348}
349
impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32CondRdWriteAdapterFiller {
    const WIDTH: usize = size_of::<Rv32CondRdWriteAdapterCols<u8>>();

    /// Expands the record at the start of `adapter_row` into [`Rv32CondRdWriteAdapterCols`].
    ///
    /// A record with `rd_ptr == u32::MAX` marks a skipped write (the conditional executor's
    /// sentinel). In that case:
    /// - `needs_write` is set to `0`;
    /// - `rd_ptr` is zeroed;
    /// - the write's auxiliary columns are zero-filled;
    /// - only `from_state` is taken from the record.
    ///
    /// Otherwise `needs_write` is `1` and the leading inner columns are filled by the
    /// unconditional filler.
    #[inline(always)]
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY: the caller must ensure that `adapter_row` holds a valid
        // `Rv32RdWriteAdapterRecord` previously written by `Rv32CondRdWriteAdapterExecutor`.
        // NOTE(review): `record` aliases `adapter_row`; see the ordering notes below.
        let record: &Rv32RdWriteAdapterRecord =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let adapter_cols: &mut Rv32CondRdWriteAdapterCols<F> = adapter_row.borrow_mut();

        // `needs_write` is the last column of the row, while the record sits at its start, so
        // this write presumably does not clobber record bytes still to be read.
        adapter_cols.needs_write = F::from_bool(record.rd_ptr != u32::MAX);

        if record.rd_ptr != u32::MAX {
            // SAFETY: `inner` is the first field of the `#[repr(C)]` column struct, so the
            // first `size_of::<Rv32RdWriteAdapterCols<u8>>()` entries are exactly the inner
            // adapter's row. This assumes the caller passes a row of the full conditional
            // width, which is strictly larger than that split point.
            unsafe {
                self.inner.fill_trace_row(
                    mem_helper,
                    adapter_row
                        .split_at_mut_unchecked(size_of::<Rv32RdWriteAdapterCols<u8>>())
                        .0,
                )
            };
        } else {
            // Write skipped: the executor recorded no aux data, so the aux columns are zeroed.
            // The AIR gives the write multiplicity `needs_write = 0` on this row.
            // NOTE(review): `rd_aux_cols.as_mut()` is also used by the inner filler alongside
            // `set_prev_data`, which suggests it does not cover the prev-data columns. Those are
            // left untouched here; confirm they hold an acceptable value (e.g. zero).
            adapter_cols.inner.rd_ptr = F::ZERO;
            mem_helper.fill_zero(adapter_cols.inner.rd_aux_cols.as_mut());
            adapter_cols.inner.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
            adapter_cols.inner.from_state.pc = F::from_canonical_u32(record.from_pc);
        }
    }
}
384}