1use std::{
2 borrow::{Borrow, BorrowMut},
3 marker::PhantomData,
4};
5
6use openvm_circuit::{
7 arch::{
8 AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
9 ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip,
10 VmAdapterInterface,
11 },
12 system::{
13 memory::{
14 offline_checker::{MemoryBridge, MemoryWriteAuxCols},
15 MemoryAddress, MemoryController, OfflineMemory, RecordId,
16 },
17 program::ProgramBus,
18 },
19};
20use openvm_circuit_primitives::utils::not;
21use openvm_circuit_primitives_derive::AlignedBorrow;
22use openvm_instructions::{
23 instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_AS,
24};
25use openvm_stark_backend::{
26 interaction::InteractionBuilder,
27 p3_air::{AirBuilder, BaseAir},
28 p3_field::{Field, FieldAlgebra, PrimeField32},
29};
30use serde::{Deserialize, Serialize};
31
32use super::RV32_REGISTER_NUM_LIMBS;
33
/// Adapter chip for instructions whose only memory effect is a single write to the register
/// named by operand `a` (`rd`); it performs no register reads.
#[derive(Debug)]
pub struct Rv32RdWriteAdapterChip<F: Field> {
    /// AIR that constrains the rows this chip generates.
    pub air: Rv32RdWriteAdapterAir,
    /// Ties the chip to the field `F` without storing a value of that type.
    _marker: PhantomData<F>,
}
40
/// Variant of [`Rv32RdWriteAdapterChip`] whose `rd` write is conditional: at runtime the write
/// happens only when instruction operand `f` is nonzero (see `postprocess`).
#[derive(Debug)]
pub struct Rv32CondRdWriteAdapterChip<F: Field> {
    /// Unconditional adapter this chip builds on (e.g. `preprocess` is delegated to it).
    inner: Rv32RdWriteAdapterChip<F>,
    /// AIR for this chip; it wraps a copy of `inner.air`.
    pub air: Rv32CondRdWriteAdapterAir,
}
48
49impl<F: PrimeField32> Rv32RdWriteAdapterChip<F> {
50 pub fn new(
51 execution_bus: ExecutionBus,
52 program_bus: ProgramBus,
53 memory_bridge: MemoryBridge,
54 ) -> Self {
55 Self {
56 air: Rv32RdWriteAdapterAir {
57 execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
58 memory_bridge,
59 },
60 _marker: PhantomData,
61 }
62 }
63}
64
65impl<F: PrimeField32> Rv32CondRdWriteAdapterChip<F> {
66 pub fn new(
67 execution_bus: ExecutionBus,
68 program_bus: ProgramBus,
69 memory_bridge: MemoryBridge,
70 ) -> Self {
71 let inner = Rv32RdWriteAdapterChip::new(execution_bus, program_bus, memory_bridge);
72 let air = Rv32CondRdWriteAdapterAir { inner: inner.air };
73 Self { inner, air }
74 }
75}
76
/// Runtime record produced by `postprocess` and consumed by `generate_trace_row`.
///
/// `#[repr(C)]` and serde derive depend on field order; do not reorder fields.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Rv32RdWriteWriteRecord {
    /// Execution state (`pc`, `timestamp`) before the instruction executed.
    pub from_state: ExecutionState<u32>,
    /// Memory record of the `rd` write, or `None` when the conditional adapter skipped it.
    pub rd_id: Option<RecordId>,
}
83
/// Trace columns of the unconditional rd-write adapter.
///
/// The struct is `#[repr(C)]` and borrowed directly from a row slice via `AlignedBorrow`, so
/// the column layout follows field order; do not reorder fields.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32RdWriteAdapterCols<T> {
    /// `pc` and `timestamp` at the start of the instruction.
    pub from_state: ExecutionState<T>,
    /// Pointer of the destination register; also sent as operand `a` on the execution bus.
    pub rd_ptr: T,
    /// Auxiliary columns for the memory write to `rd`.
    pub rd_aux_cols: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
91
/// Trace columns of the conditional rd-write adapter: the unconditional columns followed by a
/// write-enable flag. Layout follows field order (`#[repr(C)]` + `AlignedBorrow`).
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32CondRdWriteAdapterCols<T> {
    /// Columns shared with [`Rv32RdWriteAdapterCols`].
    inner: Rv32RdWriteAdapterCols<T>,
    /// 1 when this row performs the `rd` write, 0 otherwise. The AIR constrains it to be
    /// boolean and zero on invalid rows, and sends it as instruction operand `f`.
    pub needs_write: T,
}
98
/// AIR for [`Rv32RdWriteAdapterChip`]: constrains the `rd` memory write and the execution-state
/// transition.
///
/// `derive_new::new` takes its parameters in field order, so do not reorder fields.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32RdWriteAdapterAir {
    /// Bridge used to constrain the register write.
    pub(super) memory_bridge: MemoryBridge,
    /// Bridge (built from the execution and program buses) used to constrain the
    /// instruction's state transition.
    pub(super) execution_bridge: ExecutionBridge,
}
104
/// AIR for [`Rv32CondRdWriteAdapterChip`]: reuses the unconditional AIR's constraints, gated by
/// the `needs_write` column.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32CondRdWriteAdapterAir {
    /// Unconditional AIR whose `conditional_eval` does the shared constraint work.
    inner: Rv32RdWriteAdapterAir,
}
109
110impl<F: Field> BaseAir<F> for Rv32RdWriteAdapterAir {
111 fn width(&self) -> usize {
112 Rv32RdWriteAdapterCols::<F>::width()
113 }
114}
115
116impl<F: Field> BaseAir<F> for Rv32CondRdWriteAdapterAir {
117 fn width(&self) -> usize {
118 Rv32CondRdWriteAdapterCols::<F>::width()
119 }
120}
121
impl Rv32RdWriteAdapterAir {
    #[allow(clippy::type_complexity)]
    /// Shared constraint logic for the plain and the conditional rd-write adapters.
    ///
    /// If `needs_write` is `Some(flag)`:
    /// - the `rd` register write is enabled with multiplicity `flag`;
    /// - operand `f` of the executed instruction is set to `flag`;
    /// - no other constraint is placed on `flag` here (the conditional adapter's `eval`
    ///   constrains it).
    ///
    /// If `needs_write` is `None`:
    /// - the write is enabled with multiplicity `ctx.instruction.is_valid`;
    /// - operand `f` is fixed to `0`.
    ///
    /// In both cases the execution-bus interaction is enabled by `ctx.instruction.is_valid`.
    /// It sends operands `[rd_ptr, 0, immediate, RV32_REGISTER_AS, 0, f]`, advances the
    /// timestamp by 1, and moves `pc` to `ctx.to_pc` (default `pc + DEFAULT_PC_STEP`).
    fn conditional_eval<AB: InteractionBuilder>(
        &self,
        builder: &mut AB,
        local_cols: &Rv32RdWriteAdapterCols<AB::Var>,
        ctx: AdapterAirContext<
            AB::Expr,
            BasicAdapterInterface<
                AB::Expr,
                ImmInstruction<AB::Expr>,
                0,
                1,
                0,
                RV32_REGISTER_NUM_LIMBS,
            >,
        >,
        needs_write: Option<AB::Expr>,
    ) {
        let timestamp: AB::Var = local_cols.from_state.timestamp;
        // Fixed at 1 for every valid row. When the conditional chip skips the write, its
        // runtime calls `increment_timestamp` so the runtime timestamp still advances by 1.
        let timestamp_delta = 1;
        // `write_count` gates the memory write; `f` is the value sent as operand `f`.
        let (write_count, f) = if let Some(needs_write) = needs_write {
            (needs_write.clone(), needs_write)
        } else {
            (ctx.instruction.is_valid.clone(), AB::Expr::ZERO)
        };
        // Write `ctx.writes[0]` to register `rd_ptr` at the row's starting timestamp.
        self.memory_bridge
            .write(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rd_ptr,
                ),
                ctx.writes[0].clone(),
                timestamp,
                &local_cols.rd_aux_cols,
            )
            .eval(builder, write_count);

        let to_pc = ctx
            .to_pc
            .unwrap_or(local_cols.from_state.pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP));
        // Regardless of whether the write happens, the instruction itself must always be
        // executed on a valid row, so this interaction is gated by `is_valid`, not by
        // `write_count`.
        self.execution_bridge
            .execute(
                ctx.instruction.opcode,
                [
                    local_cols.rd_ptr.into(),
                    AB::Expr::ZERO,
                    ctx.instruction.immediate,
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    AB::Expr::ZERO,
                    f,
                ],
                local_cols.from_state,
                ExecutionState {
                    pc: to_pc,
                    timestamp: timestamp + AB::F::from_canonical_usize(timestamp_delta),
                },
            )
            .eval(builder, ctx.instruction.is_valid);
    }
}
192
193impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32RdWriteAdapterAir {
194 type Interface =
195 BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
196
197 fn eval(
198 &self,
199 builder: &mut AB,
200 local: &[AB::Var],
201 ctx: AdapterAirContext<AB::Expr, Self::Interface>,
202 ) {
203 let local_cols: &Rv32RdWriteAdapterCols<AB::Var> = (*local).borrow();
204 self.conditional_eval(builder, local_cols, ctx, None);
205 }
206
207 fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
208 let cols: &Rv32RdWriteAdapterCols<_> = local.borrow();
209 cols.from_state.pc
210 }
211}
212
213impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32CondRdWriteAdapterAir {
214 type Interface =
215 BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
216
217 fn eval(
218 &self,
219 builder: &mut AB,
220 local: &[AB::Var],
221 ctx: AdapterAirContext<AB::Expr, Self::Interface>,
222 ) {
223 let local_cols: &Rv32CondRdWriteAdapterCols<AB::Var> = (*local).borrow();
224
225 builder.assert_bool(local_cols.needs_write);
226 builder
227 .when::<AB::Expr>(not(ctx.instruction.is_valid.clone()))
228 .assert_zero(local_cols.needs_write);
229
230 self.inner.conditional_eval(
231 builder,
232 &local_cols.inner,
233 ctx,
234 Some(local_cols.needs_write.into()),
235 );
236 }
237
238 fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
239 let cols: &Rv32CondRdWriteAdapterCols<_> = local.borrow();
240 cols.inner.from_state.pc
241 }
242}
243
244impl<F: PrimeField32> VmAdapterChip<F> for Rv32RdWriteAdapterChip<F> {
245 type ReadRecord = ();
246 type WriteRecord = Rv32RdWriteWriteRecord;
247 type Air = Rv32RdWriteAdapterAir;
248 type Interface = BasicAdapterInterface<F, ImmInstruction<F>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
249
250 fn preprocess(
251 &mut self,
252 _memory: &mut MemoryController<F>,
253 instruction: &Instruction<F>,
254 ) -> Result<(
255 <Self::Interface as VmAdapterInterface<F>>::Reads,
256 Self::ReadRecord,
257 )> {
258 let d = instruction.d;
259 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
260
261 Ok(([], ()))
262 }
263
264 fn postprocess(
265 &mut self,
266 memory: &mut MemoryController<F>,
267 instruction: &Instruction<F>,
268 from_state: ExecutionState<u32>,
269 output: AdapterRuntimeContext<F, Self::Interface>,
270 _read_record: &Self::ReadRecord,
271 ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
272 let Instruction { a, d, .. } = *instruction;
273 let (rd_id, _) = memory.write(d, a, output.writes[0]);
274
275 Ok((
276 ExecutionState {
277 pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP),
278 timestamp: memory.timestamp(),
279 },
280 Self::WriteRecord {
281 from_state,
282 rd_id: Some(rd_id),
283 },
284 ))
285 }
286
287 fn generate_trace_row(
288 &self,
289 row_slice: &mut [F],
290 _read_record: Self::ReadRecord,
291 write_record: Self::WriteRecord,
292 memory: &OfflineMemory<F>,
293 ) {
294 let aux_cols_factory = memory.aux_cols_factory();
295 let adapter_cols: &mut Rv32RdWriteAdapterCols<F> = row_slice.borrow_mut();
296 adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32);
297 let rd = memory.record_by_id(write_record.rd_id.unwrap());
298 adapter_cols.rd_ptr = rd.pointer;
299 aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.rd_aux_cols);
300 }
301
302 fn air(&self) -> &Self::Air {
303 &self.air
304 }
305}
306
307impl<F: PrimeField32> VmAdapterChip<F> for Rv32CondRdWriteAdapterChip<F> {
308 type ReadRecord = ();
309 type WriteRecord = Rv32RdWriteWriteRecord;
310 type Air = Rv32CondRdWriteAdapterAir;
311 type Interface = BasicAdapterInterface<F, ImmInstruction<F>, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>;
312
313 fn preprocess(
314 &mut self,
315 memory: &mut MemoryController<F>,
316 instruction: &Instruction<F>,
317 ) -> Result<(
318 <Self::Interface as VmAdapterInterface<F>>::Reads,
319 Self::ReadRecord,
320 )> {
321 self.inner.preprocess(memory, instruction)
322 }
323
324 fn postprocess(
325 &mut self,
326 memory: &mut MemoryController<F>,
327 instruction: &Instruction<F>,
328 from_state: ExecutionState<u32>,
329 output: AdapterRuntimeContext<F, Self::Interface>,
330 _read_record: &Self::ReadRecord,
331 ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
332 let Instruction { a, d, .. } = *instruction;
333 let rd_id = if instruction.f != F::ZERO {
334 let (rd_id, _) = memory.write(d, a, output.writes[0]);
335 Some(rd_id)
336 } else {
337 memory.increment_timestamp();
338 None
339 };
340
341 Ok((
342 ExecutionState {
343 pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP),
344 timestamp: memory.timestamp(),
345 },
346 Self::WriteRecord { from_state, rd_id },
347 ))
348 }
349
350 fn generate_trace_row(
351 &self,
352 row_slice: &mut [F],
353 _read_record: Self::ReadRecord,
354 write_record: Self::WriteRecord,
355 memory: &OfflineMemory<F>,
356 ) {
357 let aux_cols_factory = memory.aux_cols_factory();
358 let adapter_cols: &mut Rv32CondRdWriteAdapterCols<F> = row_slice.borrow_mut();
359 adapter_cols.inner.from_state = write_record.from_state.map(F::from_canonical_u32);
360 if let Some(rd_id) = write_record.rd_id {
361 let rd = memory.record_by_id(rd_id);
362 adapter_cols.inner.rd_ptr = rd.pointer;
363 aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.inner.rd_aux_cols);
364 adapter_cols.needs_write = F::ONE;
365 }
366 }
367
368 fn air(&self) -> &Self::Air {
369 &self.air
370 }
371}