1use std::{
2 borrow::{Borrow, BorrowMut},
3 marker::PhantomData,
4};
5
6use openvm_circuit::{
7 arch::{
8 AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
9 ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip,
10 VmAdapterInterface,
11 },
12 system::{
13 memory::{
14 offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
15 MemoryAddress, MemoryController, OfflineMemory, RecordId,
16 },
17 program::ProgramBus,
18 },
19};
20use openvm_circuit_primitives::{
21 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
22 utils::not,
23};
24use openvm_circuit_primitives_derive::AlignedBorrow;
25use openvm_instructions::{
26 instruction::Instruction,
27 program::DEFAULT_PC_STEP,
28 riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
29};
30use openvm_stark_backend::{
31 interaction::InteractionBuilder,
32 p3_air::{AirBuilder, BaseAir},
33 p3_field::{Field, FieldAlgebra, PrimeField32},
34};
35use serde::{Deserialize, Serialize};
36
37use super::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
38
/// Adapter chip for the RV32 base ALU instructions.
///
/// At runtime it reads `rs1` from a register, obtains `rs2` either from a register or from the
/// instruction's immediate operand, and writes the core's result to `rd`.
pub struct Rv32BaseAluAdapterChip<F: Field> {
    /// Constraint system (AIR) checking the rows this chip generates.
    pub air: Rv32BaseAluAdapterAir,
    /// Shared bitwise lookup chip; receives range-check requests for the two low limbs of an
    /// immediate `rs2` during trace generation.
    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
    _marker: PhantomData<F>,
}
47
48impl<F: PrimeField32> Rv32BaseAluAdapterChip<F> {
49 pub fn new(
50 execution_bus: ExecutionBus,
51 program_bus: ProgramBus,
52 memory_bridge: MemoryBridge,
53 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
54 ) -> Self {
55 Self {
56 air: Rv32BaseAluAdapterAir {
57 execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
58 memory_bridge,
59 bitwise_lookup_bus: bitwise_lookup_chip.bus(),
60 },
61 bitwise_lookup_chip,
62 _marker: PhantomData,
63 }
64 }
65}
66
/// Record of the reads performed by `Rv32BaseAluAdapterChip::preprocess`.
///
/// Field order is both the `#[repr(C)]` layout and the serde serialization order; do not
/// reorder fields.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32BaseAluReadRecord<F: Field> {
    /// Memory record id of the `rs1` register read.
    pub rs1: RecordId,
    /// Memory record id of the `rs2` register read; `None` when `rs2` is an immediate.
    pub rs2: Option<RecordId>,
    /// The instruction's `c` operand when `rs2` is an immediate; zero for a register read.
    pub rs2_imm: F,
}
80
/// Record of the write performed by `Rv32BaseAluAdapterChip::postprocess`.
///
/// Field order is both the `#[repr(C)]` layout and the serde serialization order; do not
/// reorder fields.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32BaseAluWriteRecord<F: Field> {
    /// Execution state (pc, timestamp) at which the instruction started.
    pub from_state: ExecutionState<u32>,
    /// The pair returned by `MemoryController::write` for the `rd` write.
    /// Only the record id (`.0`) is read by this adapter's trace generation.
    pub rd: (RecordId, [F; 4]),
}
89
/// Trace columns of the RV32 base ALU adapter.
///
/// A trace row slice is borrowed directly as this struct, and `#[repr(C)]` fixes the field
/// layout. Field order is therefore column order; do not reorder fields.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32BaseAluAdapterCols<T> {
    /// pc and timestamp at the start of the instruction.
    pub from_state: ExecutionState<T>,
    /// Register pointer of the destination `rd`.
    pub rd_ptr: T,
    /// Register pointer of the first source `rs1`.
    pub rs1_ptr: T,
    /// Register pointer of `rs2` when `rs2_as` is set; otherwise the immediate value itself.
    pub rs2: T,
    /// Constrained boolean. 0 means `rs2` is an immediate; otherwise the value is used
    /// directly as the address space of the `rs2` read.
    pub rs2_as: T,
    /// Aux columns for the `rs1` read (index 0) and the `rs2` read (index 1).
    /// Index 1 is only filled when `rs2` is read from a register.
    pub reads_aux: [MemoryReadAuxCols<T>; 2],
    /// Aux columns for the `rd` write.
    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
103
/// AIR for the RV32 base ALU adapter.
///
/// Field order is the argument order of the `derive_new`-generated `new`.
// NOTE(review): all three fields are read by this AIR's `eval`; confirm whether this
// `allow(dead_code)` is still needed.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32BaseAluAdapterAir {
    /// Bridge for the execution/program bus interaction of the instruction.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge for the memory read/write interactions.
    pub(super) memory_bridge: MemoryBridge,
    /// Bus on which the two low limbs of an immediate `rs2` are range-checked.
    bitwise_lookup_bus: BitwiseOperationLookupBus,
}
111
112impl<F: Field> BaseAir<F> for Rv32BaseAluAdapterAir {
113 fn width(&self) -> usize {
114 Rv32BaseAluAdapterCols::<F>::width()
115 }
116}
117
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32BaseAluAdapterAir {
    // Two reads (`ctx.reads[0]` = rs1, `ctx.reads[1]` = rs2) and one write (`ctx.writes[0]` =
    // rd), each of `RV32_REGISTER_NUM_LIMBS` limbs. From the `MinimalInstruction`, this AIR
    // uses `is_valid` and `opcode`.
    type Interface = BasicAdapterInterface<
        AB::Expr,
        MinimalInstruction<AB::Expr>,
        2,
        1,
        RV32_REGISTER_NUM_LIMBS,
        RV32_REGISTER_NUM_LIMBS,
    >;

    /// Constrains one adapter row.
    ///
    /// * `rs2_as` is boolean. When it is 0, `rs2` is an immediate:
    ///   - limbs 0..=2 of `ctx.reads[1]` must recompose (weights `1`, `2^RV32_CELL_BITS`,
    ///     `2^(2*RV32_CELL_BITS)`) to `local.rs2`;
    ///   - limb 3 must equal limb 2 (the sign limb);
    ///   - the sign limb must be `0` or `2^RV32_CELL_BITS - 1`;
    ///   - limbs 0 and 1 are sent as a range check on the bitwise lookup bus.
    /// * `rs1` is read from `RV32_REGISTER_AS` on every valid row.
    /// * When `rs2_as` is set, `rs2` is read from address space `rs2_as` at pointer `rs2`.
    /// * `rd` is written to `RV32_REGISTER_AS` on every valid row.
    /// * The execution bridge receives operands `[rd_ptr, rs1_ptr, rs2, RV32_REGISTER_AS,
    ///   rs2_as]` and a timestamp delta equal to the number of `timestamp_pp` calls.
    ///
    /// The rs1 read, rs2 read and rd write use timestamps `t`, `t + 1` and `t + 2`. The rs2
    /// slot is consumed even when that read is disabled, so the delta is always 3.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local: &Rv32BaseAluAdapterCols<_> = local.borrow();
        let timestamp = local.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Returns the next timestamp (starting at `from_state.timestamp`) and counts how many
        // were handed out; the count is the delta passed to the execution bridge below.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        // Immediate form (`rs2_as == 0`): the rs2 limbs must be the sign-extended decomposition
        // of `local.rs2`.
        let rs2_limbs = ctx.reads[1].clone();
        let rs2_sign = rs2_limbs[2].clone();
        let rs2_imm = rs2_limbs[0].clone()
            + rs2_limbs[1].clone() * AB::Expr::from_canonical_usize(1 << RV32_CELL_BITS)
            + rs2_sign.clone() * AB::Expr::from_canonical_usize(1 << (2 * RV32_CELL_BITS));
        builder.assert_bool(local.rs2_as);
        let mut rs2_imm_when = builder.when(not(local.rs2_as));
        rs2_imm_when.assert_eq(local.rs2, rs2_imm);
        rs2_imm_when.assert_eq(rs2_sign.clone(), rs2_limbs[3].clone());
        // sign * (max_limb - sign) == 0  =>  sign limb is 0 or 2^RV32_CELL_BITS - 1.
        rs2_imm_when.assert_zero(
            rs2_sign.clone()
                * (AB::Expr::from_canonical_usize((1 << RV32_CELL_BITS) - 1) - rs2_sign),
        );
        // Multiplicity `is_valid - rs2_as` is 1 only on valid immediate rows. An invalid row
        // cannot set `rs2_as`, because of the `assert_one` below.
        self.bitwise_lookup_bus
            .send_range(rs2_limbs[0].clone(), rs2_limbs[1].clone())
            .eval(builder, ctx.instruction.is_valid.clone() - local.rs2_as);

        self.memory_bridge
            .read(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rs1_ptr),
                ctx.reads[0].clone(),
                timestamp_pp(),
                &local.reads_aux[0],
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Forces `is_valid == 1` whenever `rs2_as` is set, so the rs2 read below can only
        // be enabled on a valid row.
        builder
            .when(local.rs2_as)
            .assert_one(ctx.instruction.is_valid.clone());
        // The timestamp slot is always consumed, even when this read is disabled
        // (`rs2_as == 0`).
        self.memory_bridge
            .read(
                MemoryAddress::new(local.rs2_as, local.rs2),
                ctx.reads[1].clone(),
                timestamp_pp(),
                &local.reads_aux[1],
            )
            .eval(builder, local.rs2_as);

        self.memory_bridge
            .write(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), local.rd_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &local.writes_aux,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    local.rd_ptr.into(),
                    local.rs1_ptr.into(),
                    local.rs2.into(),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    local.rs2_as.into(),
                ],
                local.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the column holding the instruction's starting pc.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32BaseAluAdapterCols<_> = local.borrow();
        cols.from_state.pc
    }
}
215
216impl<F: PrimeField32> VmAdapterChip<F> for Rv32BaseAluAdapterChip<F> {
217 type ReadRecord = Rv32BaseAluReadRecord<F>;
218 type WriteRecord = Rv32BaseAluWriteRecord<F>;
219 type Air = Rv32BaseAluAdapterAir;
220 type Interface = BasicAdapterInterface<
221 F,
222 MinimalInstruction<F>,
223 2,
224 1,
225 RV32_REGISTER_NUM_LIMBS,
226 RV32_REGISTER_NUM_LIMBS,
227 >;
228
229 fn preprocess(
230 &mut self,
231 memory: &mut MemoryController<F>,
232 instruction: &Instruction<F>,
233 ) -> Result<(
234 <Self::Interface as VmAdapterInterface<F>>::Reads,
235 Self::ReadRecord,
236 )> {
237 let Instruction { b, c, d, e, .. } = *instruction;
238
239 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
240 debug_assert!(
241 e.as_canonical_u32() == RV32_IMM_AS || e.as_canonical_u32() == RV32_REGISTER_AS
242 );
243
244 let rs1 = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, b);
245 let (rs2, rs2_data, rs2_imm) = if e.is_zero() {
246 let c_u32 = c.as_canonical_u32();
247 debug_assert_eq!(c_u32 >> 24, 0);
248 memory.increment_timestamp();
249 (
250 None,
251 [
252 c_u32 as u8,
253 (c_u32 >> 8) as u8,
254 (c_u32 >> 16) as u8,
255 (c_u32 >> 16) as u8,
256 ]
257 .map(F::from_canonical_u8),
258 c,
259 )
260 } else {
261 let rs2_read = memory.read::<RV32_REGISTER_NUM_LIMBS>(e, c);
262 (Some(rs2_read.0), rs2_read.1, F::ZERO)
263 };
264
265 Ok((
266 [rs1.1, rs2_data],
267 Self::ReadRecord {
268 rs1: rs1.0,
269 rs2,
270 rs2_imm,
271 },
272 ))
273 }
274
275 fn postprocess(
276 &mut self,
277 memory: &mut MemoryController<F>,
278 instruction: &Instruction<F>,
279 from_state: ExecutionState<u32>,
280 output: AdapterRuntimeContext<F, Self::Interface>,
281 _read_record: &Self::ReadRecord,
282 ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
283 let Instruction { a, d, .. } = instruction;
284 let rd = memory.write(*d, *a, output.writes[0]);
285
286 let timestamp_delta = memory.timestamp() - from_state.timestamp;
287 debug_assert!(
288 timestamp_delta == 3,
289 "timestamp delta is {}, expected 3",
290 timestamp_delta
291 );
292
293 Ok((
294 ExecutionState {
295 pc: from_state.pc + DEFAULT_PC_STEP,
296 timestamp: memory.timestamp(),
297 },
298 Self::WriteRecord { from_state, rd },
299 ))
300 }
301
302 fn generate_trace_row(
303 &self,
304 row_slice: &mut [F],
305 read_record: Self::ReadRecord,
306 write_record: Self::WriteRecord,
307 memory: &OfflineMemory<F>,
308 ) {
309 let row_slice: &mut Rv32BaseAluAdapterCols<_> = row_slice.borrow_mut();
310 let aux_cols_factory = memory.aux_cols_factory();
311
312 let rd = memory.record_by_id(write_record.rd.0);
313 row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);
314 row_slice.rd_ptr = rd.pointer;
315
316 let rs1 = memory.record_by_id(read_record.rs1);
317 let rs2 = read_record.rs2.map(|rs2| memory.record_by_id(rs2));
318 row_slice.rs1_ptr = rs1.pointer;
319
320 if let Some(rs2) = rs2 {
321 row_slice.rs2 = rs2.pointer;
322 row_slice.rs2_as = rs2.address_space;
323 aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]);
324 aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]);
325 } else {
326 row_slice.rs2 = read_record.rs2_imm;
327 row_slice.rs2_as = F::ZERO;
328 let rs2_imm = row_slice.rs2.as_canonical_u32();
329 let mask = (1 << RV32_CELL_BITS) - 1;
330 self.bitwise_lookup_chip
331 .request_range(rs2_imm & mask, (rs2_imm >> 8) & mask);
332 aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]);
333 }
335 aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux);
336 }
337
338 fn air(&self) -> &Self::Air {
339 &self.air
340 }
341}