openvm_rv32im_circuit/base_alu/core.rs

use std::{
    array,
    borrow::{Borrow, BorrowMut},
    iter::zip,
};

use openvm_circuit::{
    arch::*,
    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::{
    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
    utils::not,
    AlignedBytesBorrow,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
use openvm_rv32im_transpiler::BaseAluOpcode;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use strum::IntoEnumIterator;

#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct BaseAluCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub a: [T; NUM_LIMBS],
    pub b: [T; NUM_LIMBS],
    pub c: [T; NUM_LIMBS],

    pub opcode_add_flag: T,
    pub opcode_sub_flag: T,
    pub opcode_xor_flag: T,
    pub opcode_or_flag: T,
    pub opcode_and_flag: T,
}

#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct BaseAluCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    pub bus: BitwiseOperationLookupBus,
    pub offset: usize,
}

impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
{
    fn width(&self) -> usize {
        BaseAluCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
    }
}
impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
{
}

impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        let flags = [
            cols.opcode_add_flag,
            cols.opcode_sub_flag,
            cols.opcode_xor_flag,
            cols.opcode_or_flag,
            cols.opcode_and_flag,
        ];

        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let a = &cols.a;
        let b = &cols.b;
        let c = &cols.c;

        // For ADD, define carry[i] = (b[i] + c[i] + carry[i - 1] - a[i]) / 2^LIMB_BITS. If
        // each carry[i] is boolean and 0 <= a[i] < 2^LIMB_BITS, it follows that
        // a[i] = (b[i] + c[i] + carry[i - 1]) % 2^LIMB_BITS, as required. The same holds
        // for SUB with carry[i] = (a[i] + c[i] - b[i] + carry[i - 1]) / 2^LIMB_BITS.
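        // Worked example (added note, LIMB_BITS = 8): b[0] = 200, c[0] = 100 gives
        // a[0] = 44 and carry[0] = (200 + 100 - 44) / 256 = 1, which is boolean. In
        // the field, dividing by 2^LIMB_BITS means multiplying by its inverse, so the
        // booleanity check forces the numerator to be exactly 0 or 2^LIMB_BITS.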
        let mut carry_add: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        let mut carry_sub: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        let carry_divide = AB::F::from_canonical_usize(1 << LIMB_BITS).inverse();

        for i in 0..NUM_LIMBS {
            // We explicitly separate the constraints for ADD and SUB to keep the degree
            // cubic. Because we constrain that the (otherwise arbitrary) carry is
            // boolean, a carry of degree greater than 1 would push the maximum
            // constraint degree to at least 4.
            carry_add[i] = AB::Expr::from(carry_divide)
                * (b[i] + c[i] - a[i]
                    + if i > 0 {
                        carry_add[i - 1].clone()
                    } else {
                        AB::Expr::ZERO
                    });
            builder
                .when(cols.opcode_add_flag)
                .assert_bool(carry_add[i].clone());
            carry_sub[i] = AB::Expr::from(carry_divide)
                * (a[i] + c[i] - b[i]
                    + if i > 0 {
                        carry_sub[i - 1].clone()
                    } else {
                        AB::Expr::ZERO
                    });
            builder
                .when(cols.opcode_sub_flag)
                .assert_bool(carry_sub[i].clone());
        }

        // Interaction with BitwiseOperationLookup to range check a for ADD and SUB, and
        // constrain a's correctness for XOR, OR, and AND.
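        // Added note: the `x_xor_y` expression below relies on the identities
        //   b ^ c = 2 * (b | c) - b - c   and   b ^ c = b + c - 2 * (b & c),
        // both consequences of b + c = (b ^ c) + 2 * (b & c), so a single XOR
        // lookup constrains the OR and AND results as well. For ADD and SUB the
        // send degenerates to send_xor(a[i], a[i], 0), a pure range check.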
        let bitwise = cols.opcode_xor_flag + cols.opcode_or_flag + cols.opcode_and_flag;
        for i in 0..NUM_LIMBS {
            let x = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * b[i];
            let y = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * c[i];
            let x_xor_y = cols.opcode_xor_flag * a[i]
                + cols.opcode_or_flag * ((AB::Expr::from_canonical_u32(2) * a[i]) - b[i] - c[i])
                + cols.opcode_and_flag * (b[i] + c[i] - (AB::Expr::from_canonical_u32(2) * a[i]));
            self.bus
                .send_xor(x, y, x_xor_y)
                .eval(builder, is_valid.clone());
        }

        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            flags.iter().zip(BaseAluOpcode::iter()).fold(
                AB::Expr::ZERO,
                |acc, (flag, local_opcode)| {
                    acc + (*flag).into() * AB::Expr::from_canonical_u8(local_opcode as u8)
                },
            ),
        );

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [cols.a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}

#[repr(C, align(4))]
#[derive(AlignedBytesBorrow, Debug)]
pub struct BaseAluCoreRecord<const NUM_LIMBS: usize> {
    pub b: [u8; NUM_LIMBS],
    pub c: [u8; NUM_LIMBS],
    // Use u8 instead of usize for better packing
    pub local_opcode: u8,
}

#[derive(Clone, Copy, derive_new::new)]
pub struct BaseAluExecutor<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub offset: usize,
}

#[derive(derive_new::new)]
pub struct BaseAluFiller<A, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    adapter: A,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
    pub offset: usize,
}

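// Added note: `BaseAluExecutor` runs during preflight execution and records only
// `b`, `c`, and the local opcode; `BaseAluFiller` later expands each record, in
// place, into a full `BaseAluCoreCols` row of the trace.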
impl<F, A, RA, const NUM_LIMBS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for BaseAluExecutor<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData: Into<[[u8; NUM_LIMBS]; 2]>,
            WriteData: From<[[u8; NUM_LIMBS]; 1]>,
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (A::RecordMut<'buf>, &'buf mut BaseAluCoreRecord<NUM_LIMBS>),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!("{:?}", BaseAluOpcode::from_usize(opcode - self.offset))
    }

    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        [core_record.b, core_record.c] = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record)
            .into();

        let rd = run_alu::<NUM_LIMBS, LIMB_BITS>(local_opcode, &core_record.b, &core_record.c);

        core_record.local_opcode = local_opcode as u8;

        self.adapter
            .write(state.memory, instruction, [rd].into(), &mut adapter_record);

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}

impl<F, A, const NUM_LIMBS: usize, const LIMB_BITS: usize> TraceFiller<F>
    for BaseAluFiller<A, NUM_LIMBS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
        // BaseAluCoreCols::width() elements
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: core_row contains a valid BaseAluCoreRecord written by the executor
        // during trace generation
        let record: &BaseAluCoreRecord<NUM_LIMBS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };
        let core_row: &mut BaseAluCoreCols<F, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut();
        // SAFETY: the following is highly unsafe. We have reinterpreted `core_row` as a
        // record buffer above, and we now perform an _overlapping_ write to `core_row`
        // as a row of field elements. This requires:
        // - the Cols and Record structs to be repr(C), with fields written in reverse
        //   order (so each write does not clobber record bytes not yet read)
        // - no field of `record` to be overwritten before it has been used or moved
        // - the alignment of `F` to be >= the alignment of the Record
        //   (AlignedBytesBorrow will panic otherwise)

        let local_opcode = BaseAluOpcode::from_usize(record.local_opcode as usize);
        let a = run_alu::<NUM_LIMBS, LIMB_BITS>(local_opcode, &record.b, &record.c);
        // PERF: needless conversion
        core_row.opcode_and_flag = F::from_bool(local_opcode == BaseAluOpcode::AND);
        core_row.opcode_or_flag = F::from_bool(local_opcode == BaseAluOpcode::OR);
        core_row.opcode_xor_flag = F::from_bool(local_opcode == BaseAluOpcode::XOR);
        core_row.opcode_sub_flag = F::from_bool(local_opcode == BaseAluOpcode::SUB);
        core_row.opcode_add_flag = F::from_bool(local_opcode == BaseAluOpcode::ADD);

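        // Added note: these lookup requests must mirror the AIR's sends exactly.
        // For ADD/SUB the request is xor(a, a) = 0, which only range-checks the
        // output limbs; for XOR/OR/AND it is xor(b, c), which also pins down the
        // bitwise result via the identities noted in the AIR.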
        if local_opcode == BaseAluOpcode::ADD || local_opcode == BaseAluOpcode::SUB {
            for a_val in a {
                self.bitwise_lookup_chip
                    .request_xor(a_val as u32, a_val as u32);
            }
        } else {
            for (b_val, c_val) in zip(record.b, record.c) {
                self.bitwise_lookup_chip
                    .request_xor(b_val as u32, c_val as u32);
            }
        }
        core_row.c = record.c.map(F::from_canonical_u8);
        core_row.b = record.b.map(F::from_canonical_u8);
        core_row.a = a.map(F::from_canonical_u8);
    }
}

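/// Applies `opcode` limb-wise to little-endian `LIMB_BITS`-bit limbs and returns
/// the result limbs; any final carry or borrow out of the top limb is discarded.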
#[inline(always)]
pub(super) fn run_alu<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    opcode: BaseAluOpcode,
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> [u8; NUM_LIMBS] {
    debug_assert!(LIMB_BITS <= 8, "specialize for bytes");
    match opcode {
        BaseAluOpcode::ADD => run_add::<NUM_LIMBS, LIMB_BITS>(x, y),
        BaseAluOpcode::SUB => run_subtract::<NUM_LIMBS, LIMB_BITS>(x, y),
        BaseAluOpcode::XOR => run_xor::<NUM_LIMBS>(x, y),
        BaseAluOpcode::OR => run_or::<NUM_LIMBS>(x, y),
        BaseAluOpcode::AND => run_and::<NUM_LIMBS>(x, y),
    }
}

#[inline(always)]
fn run_add<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> [u8; NUM_LIMBS] {
    let mut z = [0u8; NUM_LIMBS];
    let mut carry = [0u8; NUM_LIMBS];
    for i in 0..NUM_LIMBS {
        let mut overflow =
            (x[i] as u16) + (y[i] as u16) + if i > 0 { carry[i - 1] as u16 } else { 0 };
        carry[i] = (overflow >> LIMB_BITS) as u8;
        overflow &= (1u16 << LIMB_BITS) - 1;
        z[i] = overflow as u8;
    }
    z
}

#[inline(always)]
fn run_subtract<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u8; NUM_LIMBS],
    y: &[u8; NUM_LIMBS],
) -> [u8; NUM_LIMBS] {
    let mut z = [0u8; NUM_LIMBS];
    let mut carry = [0u8; NUM_LIMBS];
    for i in 0..NUM_LIMBS {
        let rhs = y[i] as u16 + if i > 0 { carry[i - 1] as u16 } else { 0 };
        if x[i] as u16 >= rhs {
            z[i] = x[i] - rhs as u8;
            carry[i] = 0;
        } else {
            z[i] = (x[i] as u16 + (1u16 << LIMB_BITS) - rhs) as u8;
            carry[i] = 1;
        }
    }
    z
}

#[inline(always)]
fn run_xor<const NUM_LIMBS: usize>(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] {
    array::from_fn(|i| x[i] ^ y[i])
}

#[inline(always)]
fn run_or<const NUM_LIMBS: usize>(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] {
    array::from_fn(|i| x[i] | y[i])
}

#[inline(always)]
fn run_and<const NUM_LIMBS: usize>(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] {
    array::from_fn(|i| x[i] & y[i])
}
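
// The tests below are a minimal added sketch, not part of the original chip: they
// exercise the limb-wise semantics of the helpers above with NUM_LIMBS = 4 and
// LIMB_BITS = 8 (the byte-limb shape used for 32-bit words), checking that carries
// and borrows propagate and that the bitwise ops act per limb.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn add_propagates_carry() {
        // 0x01FF + 0x0001 = 0x0200 over little-endian byte limbs.
        let x = [0xFF, 0x01, 0x00, 0x00];
        let y = [0x01, 0x00, 0x00, 0x00];
        assert_eq!(run_add::<4, 8>(&x, &y), [0x00, 0x02, 0x00, 0x00]);
    }

    #[test]
    fn sub_propagates_borrow() {
        // 0x0100 - 0x0001 = 0x00FF: the borrow ripples out of the low limb.
        let x = [0x00, 0x01, 0x00, 0x00];
        let y = [0x01, 0x00, 0x00, 0x00];
        assert_eq!(run_subtract::<4, 8>(&x, &y), [0xFF, 0x00, 0x00, 0x00]);
    }

    #[test]
    fn bitwise_ops_act_per_limb() {
        let x = [0b1100, 0, 0, 0];
        let y = [0b1010, 0, 0, 0];
        assert_eq!(run_alu::<4, 8>(BaseAluOpcode::XOR, &x, &y), [0b0110, 0, 0, 0]);
        assert_eq!(run_alu::<4, 8>(BaseAluOpcode::OR, &x, &y), [0b1110, 0, 0, 0]);
        assert_eq!(run_alu::<4, 8>(BaseAluOpcode::AND, &x, &y), [0b1000, 0, 0, 0]);
    }
}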