1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7 AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
8 VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::{
11 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12 utils::not,
13};
14use openvm_circuit_primitives_derive::AlignedBorrow;
15use openvm_instructions::{instruction::Instruction, LocalOpcode};
16use openvm_rv32im_transpiler::BaseAluOpcode;
17use openvm_stark_backend::{
18 interaction::InteractionBuilder,
19 p3_air::{AirBuilder, BaseAir},
20 p3_field::{Field, FieldAlgebra, PrimeField32},
21 rap::BaseAirWithPublicValues,
22};
23use serde::{de::DeserializeOwned, Deserialize, Serialize};
24use serde_big_array::BigArray;
25use strum::IntoEnumIterator;
26
/// Core trace columns for one base ALU row.
///
/// `#[repr(C)]` plus `AlignedBorrow` allow a raw row slice to be reinterpreted
/// as this struct (see the `borrow()`/`borrow_mut()` calls in `eval` and
/// `generate_trace_row`).
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct BaseAluCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    // Result limbs: a = b op c.
    pub a: [T; NUM_LIMBS],
    // First operand limbs.
    pub b: [T; NUM_LIMBS],
    // Second operand limbs.
    pub c: [T; NUM_LIMBS],

    // One-hot opcode selectors. The AIR constrains each to be boolean and
    // their sum to be boolean, so at most one is set; an all-zero row is a
    // padding (invalid) row.
    pub opcode_add_flag: T,
    pub opcode_sub_flag: T,
    pub opcode_xor_flag: T,
    pub opcode_or_flag: T,
    pub opcode_and_flag: T,
}
40
/// AIR for the base ALU core.
#[derive(Copy, Clone, Debug)]
pub struct BaseAluCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    // Lookup bus used both for bitwise XOR lookups and for range checking
    // output limbs (see `eval`).
    pub bus: BitwiseOperationLookupBus,
    // Start of this chip's opcode range in the global opcode space.
    offset: usize,
}
46
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
{
    // Width is fully determined by the column struct layout.
    fn width(&self) -> usize {
        BaseAluCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
// Marker impl: this AIR exposes no public values.
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
{
}
58
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Constrains `a = b op c` for op in {ADD, SUB, XOR, OR, AND}, where each
    /// value is represented as `NUM_LIMBS` limbs of `LIMB_BITS` bits, limb 0
    /// least significant (carries propagate from index 0 upward).
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        // NOTE: order must match BaseAluOpcode::iter() — it is zipped with
        // these flags below to reconstruct the opcode.
        let flags = [
            cols.opcode_add_flag,
            cols.opcode_sub_flag,
            cols.opcode_xor_flag,
            cols.opcode_or_flag,
            cols.opcode_and_flag,
        ];

        // Every flag is boolean, and the sum of flags is itself boolean, so at
        // most one flag is set. The sum doubles as the row-validity indicator:
        // all flags zero means a padding row.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let a = &cols.a;
        let b = &cols.b;
        let c = &cols.c;

        // ADD/SUB correctness via carry/borrow propagation:
        //   carry_add[i] = (b[i] + c[i] + carry_add[i-1] - a[i]) / 2^LIMB_BITS
        // which, when the add flag is set, is constrained to be boolean —
        // i.e. a[i] is the low LIMB_BITS bits of the limb sum and the carry
        // out is 0 or 1. SUB mirrors this with the roles of a and b swapped
        // (b = a + c with borrow), i.e. a = b - c.
        let mut carry_add: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        let mut carry_sub: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        // Precomputed multiplicative inverse of 2^LIMB_BITS (division in-field).
        let carry_divide = AB::F::from_canonical_usize(1 << LIMB_BITS).inverse();

        for i in 0..NUM_LIMBS {
            carry_add[i] = AB::Expr::from(carry_divide)
                * (b[i] + c[i] - a[i]
                    + if i > 0 {
                        carry_add[i - 1].clone()
                    } else {
                        AB::Expr::ZERO
                    });
            builder
                .when(cols.opcode_add_flag)
                .assert_bool(carry_add[i].clone());
            carry_sub[i] = AB::Expr::from(carry_divide)
                * (a[i] + c[i] - b[i]
                    + if i > 0 {
                        carry_sub[i - 1].clone()
                    } else {
                        AB::Expr::ZERO
                    });
            builder
                .when(cols.opcode_sub_flag)
                .assert_bool(carry_sub[i].clone());
        }

        // All opcodes interact with the XOR lookup bus once per limb.
        // For bitwise opcodes (`bitwise` = 1): send (b[i], c[i], b[i]^c[i]),
        // where b^c is derived from the claimed output a via the identities
        //   XOR: b^c = a
        //   OR:  b^c = 2*(b|c) - b - c = 2a - b - c
        //   AND: b^c = b + c - 2*(b&c) = b + c - 2a
        // For ADD/SUB (`bitwise` = 0): send (a[i], a[i], 0) — valid since
        // a^a = 0 — which range checks each output limb to LIMB_BITS bits.
        let bitwise = cols.opcode_xor_flag + cols.opcode_or_flag + cols.opcode_and_flag;
        for i in 0..NUM_LIMBS {
            let x = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * b[i];
            let y = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * c[i];
            let x_xor_y = cols.opcode_xor_flag * a[i]
                + cols.opcode_or_flag * ((AB::Expr::from_canonical_u32(2) * a[i]) - b[i] - c[i])
                + cols.opcode_and_flag * (b[i] + c[i] - (AB::Expr::from_canonical_u32(2) * a[i]));
            self.bus
                .send_xor(x, y, x_xor_y)
                .eval(builder, is_valid.clone());
        }

        // Reconstruct the executed opcode as the flag-weighted sum of local
        // opcode ordinals, then lift it to the global opcode space.
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            flags.iter().zip(BaseAluOpcode::iter()).fold(
                AB::Expr::ZERO,
                |acc, (flag, local_opcode)| {
                    acc + (*flag).into() * AB::Expr::from_canonical_u8(local_opcode as u8)
                },
            ),
        );

        // Hand the adapter the claimed reads (b, c), the write (a), and the
        // decoded instruction; to_pc = None means the default PC increment.
        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [cols.a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    // Base of this chip's opcode range within the global opcode space.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
167
/// Execution record for a single base ALU instruction; carries exactly the
/// data needed to fill one row of core columns in `generate_trace_row`.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct BaseAluCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub opcode: BaseAluOpcode,
    // `BigArray` because serde does not derive for const-generic arrays.
    #[serde(with = "BigArray")]
    pub a: [T; NUM_LIMBS],
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    #[serde(with = "BigArray")]
    pub c: [T; NUM_LIMBS],
}
180
/// Core chip pairing the base ALU AIR with the shared bitwise lookup chip
/// that serves its XOR / range-check lookups.
pub struct BaseAluCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub air: BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
185
186impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAluCoreChip<NUM_LIMBS, LIMB_BITS> {
187 pub fn new(
188 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
189 offset: usize,
190 ) -> Self {
191 Self {
192 air: BaseAluCoreAir {
193 bus: bitwise_lookup_chip.bus(),
194 offset,
195 },
196 bitwise_lookup_chip,
197 }
198 }
199}
200
201impl<F, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreChip<F, I>
202 for BaseAluCoreChip<NUM_LIMBS, LIMB_BITS>
203where
204 F: PrimeField32,
205 I: VmAdapterInterface<F>,
206 I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
207 I::Writes: From<[[F; NUM_LIMBS]; 1]>,
208{
209 type Record = BaseAluCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
210 type Air = BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>;
211
212 #[allow(clippy::type_complexity)]
213 fn execute_instruction(
214 &self,
215 instruction: &Instruction<F>,
216 _from_pc: u32,
217 reads: I::Reads,
218 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
219 let Instruction { opcode, .. } = instruction;
220 let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
221
222 let data: [[F; NUM_LIMBS]; 2] = reads.into();
223 let b = data[0].map(|x| x.as_canonical_u32());
224 let c = data[1].map(|y| y.as_canonical_u32());
225 let a = run_alu::<NUM_LIMBS, LIMB_BITS>(local_opcode, &b, &c);
226
227 let output = AdapterRuntimeContext {
228 to_pc: None,
229 writes: [a.map(F::from_canonical_u32)].into(),
230 };
231
232 if local_opcode == BaseAluOpcode::ADD || local_opcode == BaseAluOpcode::SUB {
233 for a_val in a {
234 self.bitwise_lookup_chip.request_xor(a_val, a_val);
235 }
236 } else {
237 for (b_val, c_val) in b.iter().zip(c.iter()) {
238 self.bitwise_lookup_chip.request_xor(*b_val, *c_val);
239 }
240 }
241
242 let record = Self::Record {
243 opcode: local_opcode,
244 a: a.map(F::from_canonical_u32),
245 b: data[0],
246 c: data[1],
247 };
248
249 Ok((output, record))
250 }
251
252 fn get_opcode_name(&self, opcode: usize) -> String {
253 format!("{:?}", BaseAluOpcode::from_usize(opcode - self.air.offset))
254 }
255
256 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
257 let row_slice: &mut BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut();
258 row_slice.a = record.a;
259 row_slice.b = record.b;
260 row_slice.c = record.c;
261 row_slice.opcode_add_flag = F::from_bool(record.opcode == BaseAluOpcode::ADD);
262 row_slice.opcode_sub_flag = F::from_bool(record.opcode == BaseAluOpcode::SUB);
263 row_slice.opcode_xor_flag = F::from_bool(record.opcode == BaseAluOpcode::XOR);
264 row_slice.opcode_or_flag = F::from_bool(record.opcode == BaseAluOpcode::OR);
265 row_slice.opcode_and_flag = F::from_bool(record.opcode == BaseAluOpcode::AND);
266 }
267
268 fn air(&self) -> &Self::Air {
269 &self.air
270 }
271}
272
273pub(super) fn run_alu<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
274 opcode: BaseAluOpcode,
275 x: &[u32; NUM_LIMBS],
276 y: &[u32; NUM_LIMBS],
277) -> [u32; NUM_LIMBS] {
278 match opcode {
279 BaseAluOpcode::ADD => run_add::<NUM_LIMBS, LIMB_BITS>(x, y),
280 BaseAluOpcode::SUB => run_subtract::<NUM_LIMBS, LIMB_BITS>(x, y),
281 BaseAluOpcode::XOR => run_xor::<NUM_LIMBS, LIMB_BITS>(x, y),
282 BaseAluOpcode::OR => run_or::<NUM_LIMBS, LIMB_BITS>(x, y),
283 BaseAluOpcode::AND => run_and::<NUM_LIMBS, LIMB_BITS>(x, y),
284 }
285}
286
/// Limb-wise addition with carry propagation.
///
/// The result wraps modulo 2^(NUM_LIMBS * LIMB_BITS): the final carry out is
/// discarded. Assumes each input limb fits in LIMB_BITS bits.
fn run_add<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let limb_mask = (1u32 << LIMB_BITS) - 1;
    let mut sum = [0u32; NUM_LIMBS];
    let mut carry = 0u32;
    for (i, limb) in sum.iter_mut().enumerate() {
        let total = x[i] + y[i] + carry;
        carry = total >> LIMB_BITS;
        *limb = total & limb_mask;
    }
    sum
}
300
/// Limb-wise subtraction with borrow propagation.
///
/// The result wraps modulo 2^(NUM_LIMBS * LIMB_BITS) (two's-complement style).
/// Assumes each input limb fits in LIMB_BITS bits.
fn run_subtract<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let limb_base = 1u32 << LIMB_BITS;
    let mut diff = [0u32; NUM_LIMBS];
    let mut borrow = 0u32;
    for (i, limb) in diff.iter_mut().enumerate() {
        let rhs = y[i] + borrow;
        // Add the limb base before subtracting so the intermediate stays
        // non-negative, then mask it back off; when x[i] >= rhs the extra
        // base bit is removed by the mask, so this equals x[i] - rhs.
        *limb = (x[i] + limb_base - rhs) & (limb_base - 1);
        borrow = u32::from(x[i] < rhs);
    }
    diff
}
319
/// Limb-wise bitwise XOR (no carries; LIMB_BITS is unused here).
fn run_xor<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut z = [0u32; NUM_LIMBS];
    for ((zi, xi), yi) in z.iter_mut().zip(x).zip(y) {
        *zi = xi ^ yi;
    }
    z
}
326
/// Limb-wise bitwise OR (no carries; LIMB_BITS is unused here).
fn run_or<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut z = [0u32; NUM_LIMBS];
    for ((zi, xi), yi) in z.iter_mut().zip(x).zip(y) {
        *zi = xi | yi;
    }
    z
}
333
/// Limb-wise bitwise AND (no carries; LIMB_BITS is unused here).
fn run_and<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut z = [0u32; NUM_LIMBS];
    for ((zi, xi), yi) in z.iter_mut().zip(x).zip(y) {
        *zi = xi & yi;
    }
    z
}