1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4 iter::zip,
5};
6
7use openvm_circuit::{
8 arch::*,
9 system::memory::{online::TracingMemory, MemoryAuxColsFactory},
10};
11use openvm_circuit_primitives::{
12 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
13 utils::not,
14 AlignedBytesBorrow,
15};
16use openvm_circuit_primitives_derive::AlignedBorrow;
17use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
18use openvm_rv32im_transpiler::BaseAluOpcode;
19use openvm_stark_backend::{
20 interaction::InteractionBuilder,
21 p3_air::{AirBuilder, BaseAir},
22 p3_field::{Field, FieldAlgebra, PrimeField32},
23 rap::BaseAirWithPublicValues,
24};
25use strum::IntoEnumIterator;
26
27#[repr(C)]
28#[derive(AlignedBorrow, Debug)]
29pub struct BaseAluCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
30 pub a: [T; NUM_LIMBS],
31 pub b: [T; NUM_LIMBS],
32 pub c: [T; NUM_LIMBS],
33
34 pub opcode_add_flag: T,
35 pub opcode_sub_flag: T,
36 pub opcode_xor_flag: T,
37 pub opcode_or_flag: T,
38 pub opcode_and_flag: T,
39}
40
41#[derive(Copy, Clone, Debug, derive_new::new)]
42pub struct BaseAluCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
43 pub bus: BitwiseOperationLookupBus,
44 pub offset: usize,
45}
46
47impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
48 for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
49{
50 fn width(&self) -> usize {
51 BaseAluCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
52 }
53}
54impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
55 for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
56{
57}
58
59impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
60 for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
61where
62 AB: InteractionBuilder,
63 I: VmAdapterInterface<AB::Expr>,
64 I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
65 I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
66 I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
67{
68 fn eval(
69 &self,
70 builder: &mut AB,
71 local_core: &[AB::Var],
72 _from_pc: AB::Var,
73 ) -> AdapterAirContext<AB::Expr, I> {
74 let cols: &BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
75 let flags = [
76 cols.opcode_add_flag,
77 cols.opcode_sub_flag,
78 cols.opcode_xor_flag,
79 cols.opcode_or_flag,
80 cols.opcode_and_flag,
81 ];
82
83 let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
84 builder.assert_bool(flag);
85 acc + flag.into()
86 });
87 builder.assert_bool(is_valid.clone());
88
89 let a = &cols.a;
90 let b = &cols.b;
91 let c = &cols.c;
92
93 let mut carry_add: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
98 let mut carry_sub: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
99 let carry_divide = AB::F::from_canonical_usize(1 << LIMB_BITS).inverse();
100
101 for i in 0..NUM_LIMBS {
102 carry_add[i] = AB::Expr::from(carry_divide)
106 * (b[i] + c[i] - a[i]
107 + if i > 0 {
108 carry_add[i - 1].clone()
109 } else {
110 AB::Expr::ZERO
111 });
112 builder
113 .when(cols.opcode_add_flag)
114 .assert_bool(carry_add[i].clone());
115 carry_sub[i] = AB::Expr::from(carry_divide)
116 * (a[i] + c[i] - b[i]
117 + if i > 0 {
118 carry_sub[i - 1].clone()
119 } else {
120 AB::Expr::ZERO
121 });
122 builder
123 .when(cols.opcode_sub_flag)
124 .assert_bool(carry_sub[i].clone());
125 }
126
127 let bitwise = cols.opcode_xor_flag + cols.opcode_or_flag + cols.opcode_and_flag;
130 for i in 0..NUM_LIMBS {
131 let x = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * b[i];
132 let y = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * c[i];
133 let x_xor_y = cols.opcode_xor_flag * a[i]
134 + cols.opcode_or_flag * ((AB::Expr::from_canonical_u32(2) * a[i]) - b[i] - c[i])
135 + cols.opcode_and_flag * (b[i] + c[i] - (AB::Expr::from_canonical_u32(2) * a[i]));
136 self.bus
137 .send_xor(x, y, x_xor_y)
138 .eval(builder, is_valid.clone());
139 }
140
141 let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
142 self,
143 flags.iter().zip(BaseAluOpcode::iter()).fold(
144 AB::Expr::ZERO,
145 |acc, (flag, local_opcode)| {
146 acc + (*flag).into() * AB::Expr::from_canonical_u8(local_opcode as u8)
147 },
148 ),
149 );
150
151 AdapterAirContext {
152 to_pc: None,
153 reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
154 writes: [cols.a.map(Into::into)].into(),
155 instruction: MinimalInstruction {
156 is_valid,
157 opcode: expected_opcode,
158 }
159 .into(),
160 }
161 }
162
163 fn start_offset(&self) -> usize {
164 self.offset
165 }
166}
167
168#[repr(C, align(4))]
169#[derive(AlignedBytesBorrow, Debug)]
170pub struct BaseAluCoreRecord<const NUM_LIMBS: usize> {
171 pub b: [u8; NUM_LIMBS],
172 pub c: [u8; NUM_LIMBS],
173 pub local_opcode: u8,
175}
176
177#[derive(Clone, Copy, derive_new::new)]
178pub struct BaseAluExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
179 adapter: A,
180 pub offset: usize,
181}
182
183#[derive(derive_new::new)]
184pub struct BaseAluFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
185 adapter: A,
186 pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
187 pub offset: usize,
188}
189
190impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
191 for BaseAluExecutor<A, NUM_LIMBS, LIMB_BITS>
192where
193 F: PrimeField32,
194 A: 'static
195 + AdapterTraceExecutor<
196 F,
197 ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
198 WriteData: From<[[u8; NUM_LIMBS]; 1]>,
199 >,
200 for<'buf> RA: RecordArena<
201 'buf,
202 EmptyAdapterCoreLayout<F, A>,
203 (A::RecordMut<'buf>, &'buf mut BaseAluCoreRecord<NUM_LIMBS>),
204 >,
205{
206 fn get_opcode_name(&self, opcode: usize) -> String {
207 format!("{:?}", BaseAluOpcode::from_usize(opcode - self.offset))
208 }
209
210 fn execute(
211 &self,
212 state: VmStateMut<F, TracingMemory, RA>,
213 instruction: &Instruction<F>,
214 ) -> Result<(), ExecutionError> {
215 let Instruction { opcode, .. } = instruction;
216
217 let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset));
218 let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
219
220 A::start(*state.pc, state.memory, &mut adapter_record);
221
222 [core_record.b, core_record.c] = self
223 .adapter
224 .read(state.memory, instruction, &mut adapter_record)
225 .into();
226
227 let rd = run_alu::<NUM_LIMBS, LIMB_BITS>(local_opcode, &core_record.b, &core_record.c);
228
229 core_record.local_opcode = local_opcode as u8;
230
231 self.adapter
232 .write(state.memory, instruction, [rd].into(), &mut adapter_record);
233
234 *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);
235
236 Ok(())
237 }
238}
239
240impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
241 for BaseAluFiller<A, NUM_LIMBS, LIMB_BITS>
242where
243 F: PrimeField32,
244 A: 'static + AdapterTraceFiller<F>,
245{
246 fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
247 let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
250 self.adapter.fill_trace_row(mem_helper, adapter_row);
251 let record: &BaseAluCoreRecord<NUM_LIMBS> =
254 unsafe { get_record_from_slice(&mut core_row, ()) };
255 let core_row: &mut BaseAluCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();
256 let local_opcode = BaseAluOpcode::from_usize(record.local_opcode as usize);
266 let a = run_alu::<NUM_LIMBS, LIMB_BITS>(local_opcode, &record.b, &record.c);
267 core_row.opcode_and_flag = F::from_bool(local_opcode == BaseAluOpcode::AND);
269 core_row.opcode_or_flag = F::from_bool(local_opcode == BaseAluOpcode::OR);
270 core_row.opcode_xor_flag = F::from_bool(local_opcode == BaseAluOpcode::XOR);
271 core_row.opcode_sub_flag = F::from_bool(local_opcode == BaseAluOpcode::SUB);
272 core_row.opcode_add_flag = F::from_bool(local_opcode == BaseAluOpcode::ADD);
273
274 if local_opcode == BaseAluOpcode::ADD || local_opcode == BaseAluOpcode::SUB {
275 for a_val in a {
276 self.bitwise_lookup_chip
277 .request_xor(a_val as u32, a_val as u32);
278 }
279 } else {
280 for (b_val, c_val) in zip(record.b, record.c) {
281 self.bitwise_lookup_chip
282 .request_xor(b_val as u32, c_val as u32);
283 }
284 }
285 core_row.c = record.c.map(F::from_canonical_u8);
286 core_row.b = record.b.map(F::from_canonical_u8);
287 core_row.a = a.map(F::from_canonical_u8);
288 }
289}
290
291#[inline(always)]
292pub(super) fn run_alu<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
293 opcode: BaseAluOpcode,
294 x: &[u8; NUM_LIMBS],
295 y: &[u8; NUM_LIMBS],
296) -> [u8; NUM_LIMBS] {
297 debug_assert!(LIMB_BITS <= 8, "specialize for bytes");
298 match opcode {
299 BaseAluOpcode::ADD => run_add::<NUM_LIMBS, LIMB_BITS>(x, y),
300 BaseAluOpcode::SUB => run_subtract::<NUM_LIMBS, LIMB_BITS>(x, y),
301 BaseAluOpcode::XOR => run_xor::<NUM_LIMBS>(x, y),
302 BaseAluOpcode::OR => run_or::<NUM_LIMBS>(x, y),
303 BaseAluOpcode::AND => run_and::<NUM_LIMBS>(x, y),
304 }
305}
306
/// Adds two little-endian limb arrays, wrapping modulo `2^(NUM_LIMBS * LIMB_BITS)`.
///
/// Each output limb is the low `LIMB_BITS` bits of `x[i] + y[i] + carry`; the bits above
/// that become the carry into the next limb. The carry out of the top limb is discarded.
#[inline(always)]
fn run_add<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> [u8; NUM_LIMBS] {
    let mut z = [0u8; NUM_LIMBS];
    // Only the previous limb's carry is ever needed, so a running value replaces an array.
    let mut carry = 0u8;
    for (z_limb, (&x_limb, &y_limb)) in z.iter_mut().zip(x.iter().zip(y.iter())) {
        let sum = u16::from(x_limb) + u16::from(y_limb) + u16::from(carry);
        carry = (sum >> LIMB_BITS) as u8;
        *z_limb = (sum & ((1u16 << LIMB_BITS) - 1)) as u8;
    }
    z
}
323
/// Subtracts `y` from `x` as little-endian limb arrays, wrapping modulo
/// `2^(NUM_LIMBS * LIMB_BITS)`.
///
/// For each limb the amount to subtract is `y[i]` plus the borrow from the previous limb.
/// When `x[i]` is too small, `2^LIMB_BITS` is added first and a borrow of 1 is passed on.
/// The borrow out of the top limb is discarded.
#[inline(always)]
fn run_subtract<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> [u8; NUM_LIMBS] {
    let mut z = [0u8; NUM_LIMBS];
    // Only the previous limb's borrow is ever needed, so a running value replaces an array.
    let mut borrow = 0u16;
    for (z_limb, (&x_limb, &y_limb)) in z.iter_mut().zip(x.iter().zip(y.iter())) {
        // `rhs` can reach 256 (y = 255 plus a borrow), hence the u16 arithmetic.
        let rhs = u16::from(y_limb) + borrow;
        if u16::from(x_limb) >= rhs {
            *z_limb = x_limb - rhs as u8;
            borrow = 0;
        } else {
            *z_limb = (u16::from(x_limb) + (1u16 << LIMB_BITS) - rhs) as u8;
            borrow = 1;
        }
    }
    z
}
343
/// Returns the limb-wise bitwise XOR of `x` and `y`.
#[inline(always)]
fn run_xor<const NUM_LIMBS: usize>(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] {
    array::from_fn(|i| x[i] ^ y[i])
}
348
/// Returns the limb-wise bitwise OR of `x` and `y`.
#[inline(always)]
fn run_or<const NUM_LIMBS: usize>(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] {
    array::from_fn(|i| x[i] | y[i])
}
353
/// Returns the limb-wise bitwise AND of `x` and `y`.
#[inline(always)]
fn run_and<const NUM_LIMBS: usize>(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] {
    array::from_fn(|i| x[i] & y[i])
}