openvm_rv32im_circuit/base_alu/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7    AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface,
8    VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12    utils::not,
13};
14use openvm_circuit_primitives_derive::AlignedBorrow;
15use openvm_instructions::{instruction::Instruction, LocalOpcode};
16use openvm_rv32im_transpiler::BaseAluOpcode;
17use openvm_stark_backend::{
18    interaction::InteractionBuilder,
19    p3_air::{AirBuilder, BaseAir},
20    p3_field::{Field, FieldAlgebra, PrimeField32},
21    rap::BaseAirWithPublicValues,
22};
23use serde::{de::DeserializeOwned, Deserialize, Serialize};
24use serde_big_array::BigArray;
25use strum::IntoEnumIterator;
26
/// Trace columns owned by the base ALU core for a single row.
///
/// A row slice is borrowed as this struct (`local_core.borrow()` in the AIR,
/// `row_slice.borrow_mut()` in trace generation). With `#[repr(C)]` the field order below
/// therefore defines the column layout, so fields must not be reordered.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct BaseAluCoreCols<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Result limbs (the value written back by the instruction).
    pub a: [T; NUM_LIMBS],
    /// First operand limbs.
    pub b: [T; NUM_LIMBS],
    /// Second operand limbs.
    pub c: [T; NUM_LIMBS],

    // Opcode selector flags. The AIR constrains each flag and their sum to be boolean, so at
    // most one is set; all zero means the row carries no instruction (`is_valid == 0`).
    pub opcode_add_flag: T,
    pub opcode_sub_flag: T,
    pub opcode_xor_flag: T,
    pub opcode_or_flag: T,
    pub opcode_and_flag: T,
}
40
/// AIR for the base ALU core (ADD, SUB, XOR, OR, AND on `NUM_LIMBS` limbs of `LIMB_BITS` bits).
#[derive(Copy, Clone, Debug)]
pub struct BaseAluCoreAir<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Bitwise-lookup bus on which one XOR lookup per limb is sent.
    pub bus: BitwiseOperationLookupBus,
    /// Offset subtracted from a global opcode to obtain the `BaseAluOpcode` local index.
    offset: usize,
}
46
47impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAir<F>
48    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
49{
50    fn width(&self) -> usize {
51        BaseAluCoreCols::<F, NUM_LIMBS, LIMB_BITS>::width()
52    }
53}
54impl<F: Field, const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
55    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
56{
57}
58
/// Core AIR constraints for the base ALU opcodes.
///
/// The operands `b` and `c` are handed to the adapter as its two reads and the result `a` as
/// its single write. The adapter also receives `is_valid` (the sum of the five opcode flags)
/// and the global opcode reconstructed from those flags.
///
/// NOTE(review): constraints and bus interactions are emitted in a fixed order below; reordering
/// them presumably changes the circuit definition (and thus its keys) — confirm before refactoring.
impl<AB, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; NUM_LIMBS]; 2]>,
    I::Writes: From<[[AB::Expr; NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<MinimalInstruction<AB::Expr>>,
{
    /// Adds the core constraints for one row and returns the adapter context.
    ///
    /// Constraints:
    /// - every opcode flag is boolean, and so is their sum (`is_valid`), so at most one is set;
    /// - for ADD and SUB, the implied per-limb carries (borrows for SUB) are boolean;
    /// - one XOR lookup per limb, gated by `is_valid`. It range-checks `a` for ADD/SUB and
    ///   pins `a` to the bitwise result for XOR/OR/AND.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = local_core.borrow();
        // This order must match `BaseAluOpcode::iter()`: the two are zipped when the
        // expected opcode is reconstructed below.
        let flags = [
            cols.opcode_add_flag,
            cols.opcode_sub_flag,
            cols.opcode_xor_flag,
            cols.opcode_or_flag,
            cols.opcode_and_flag,
        ];

        // Each flag is boolean and their sum is boolean, so at most one opcode is selected.
        // The sum doubles as the row's validity indicator.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag.into()
        });
        builder.assert_bool(is_valid.clone());

        let a = &cols.a;
        let b = &cols.b;
        let c = &cols.c;

        // For ADD, define carry[i] = (b[i] + c[i] + carry[i - 1] - a[i]) / 2^LIMB_BITS. If
        // each carry[i] is boolean and 0 <= a[i] < 2^LIMB_BITS, it can be proven that
        // a[i] = (b[i] + c[i] + carry[i - 1]) % 2^LIMB_BITS as necessary. The same holds for
        // SUB when carry[i] is (a[i] + c[i] - b[i] + carry[i - 1]) / 2^LIMB_BITS, with carry
        // acting as the borrow of a = b - c. The range of a[i] is enforced by the XOR lookup
        // further below.
        let mut carry_add: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        let mut carry_sub: [AB::Expr; NUM_LIMBS] = array::from_fn(|_| AB::Expr::ZERO);
        // Field inverse of 2^LIMB_BITS: multiplying by it divides by 2^LIMB_BITS in the field.
        let carry_divide = AB::F::from_canonical_usize(1 << LIMB_BITS).inverse();

        for i in 0..NUM_LIMBS {
            // We explicitly separate the constraints for ADD and SUB in order to keep degree
            // cubic. Each carry is asserted boolean under its own opcode flag. Merging the
            // ADD and SUB carries into one flag-weighted expression would give the carry
            // degree > 1, and the max-degree constraint would then be at least 4.
            carry_add[i] = AB::Expr::from(carry_divide)
                * (b[i] + c[i] - a[i]
                    + if i > 0 {
                        carry_add[i - 1].clone()
                    } else {
                        AB::Expr::ZERO
                    });
            builder
                .when(cols.opcode_add_flag)
                .assert_bool(carry_add[i].clone());
            carry_sub[i] = AB::Expr::from(carry_divide)
                * (a[i] + c[i] - b[i]
                    + if i > 0 {
                        carry_sub[i - 1].clone()
                    } else {
                        AB::Expr::ZERO
                    });
            builder
                .when(cols.opcode_sub_flag)
                .assert_bool(carry_sub[i].clone());
        }

        // Interaction with BitwiseOperationLookup to range check a for ADD and SUB, and
        // constrain a's correctness for XOR, OR, and AND.
        // When no bitwise flag is set (ADD/SUB), the lookup is (a, a, 0), assuming
        // `not(e)` evaluates to `1 - e`. Otherwise it is (b, c, b ^ c), where b ^ c is
        // expressed through a:
        //   XOR: a;  OR: 2a - b - c (since b ^ c = 2(b | c) - b - c);
        //   AND: b + c - 2a (since b ^ c = b + c - 2(b & c)).
        let bitwise = cols.opcode_xor_flag + cols.opcode_or_flag + cols.opcode_and_flag;
        for i in 0..NUM_LIMBS {
            let x = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * b[i];
            let y = not::<AB::Expr>(bitwise.clone()) * a[i] + bitwise.clone() * c[i];
            let x_xor_y = cols.opcode_xor_flag * a[i]
                + cols.opcode_or_flag * ((AB::Expr::from_canonical_u32(2) * a[i]) - b[i] - c[i])
                + cols.opcode_and_flag * (b[i] + c[i] - (AB::Expr::from_canonical_u32(2) * a[i]));
            self.bus
                .send_xor(x, y, x_xor_y)
                .eval(builder, is_valid.clone());
        }

        // Local opcode = sum of flag_i * (index of the i-th BaseAluOpcode). It is then mapped
        // to the global opcode space by `expr_to_global_expr`, which presumably adds
        // `start_offset()`.
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            flags.iter().zip(BaseAluOpcode::iter()).fold(
                AB::Expr::ZERO,
                |acc, (flag, local_opcode)| {
                    acc + (*flag).into() * AB::Expr::from_canonical_u8(local_opcode as u8)
                },
            ),
        );

        AdapterAirContext {
            to_pc: None,
            reads: [cols.b.map(Into::into), cols.c.map(Into::into)].into(),
            writes: [cols.a.map(Into::into)].into(),
            instruction: MinimalInstruction {
                is_valid,
                opcode: expected_opcode,
            }
            .into(),
        }
    }

    /// Offset that maps `BaseAluOpcode` local indices into the global opcode space.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
167
/// Execution record for one base ALU instruction, later turned into a trace row by
/// `generate_trace_row`.
///
/// The limb arrays are (de)serialized through `serde_big_array::BigArray` because their length
/// is the const generic `NUM_LIMBS`. Field order is kept stable because of `#[repr(C)]` and
/// serialization.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct BaseAluCoreRecord<T, const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// Local (offset-removed) opcode that was executed.
    pub opcode: BaseAluOpcode,
    /// Result limbs.
    #[serde(with = "BigArray")]
    pub a: [T; NUM_LIMBS],
    /// First operand limbs, as read by the adapter.
    #[serde(with = "BigArray")]
    pub b: [T; NUM_LIMBS],
    /// Second operand limbs, as read by the adapter.
    #[serde(with = "BigArray")]
    pub c: [T; NUM_LIMBS],
}
180
/// Execution-side chip for the base ALU core: runs instructions and generates trace rows.
pub struct BaseAluCoreChip<const NUM_LIMBS: usize, const LIMB_BITS: usize> {
    /// The AIR whose constraints the generated rows must satisfy.
    pub air: BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>,
    /// Shared bitwise lookup chip. Each executed instruction records `NUM_LIMBS` XOR requests
    /// here, mirroring the lookups sent by the AIR.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
}
185
186impl<const NUM_LIMBS: usize, const LIMB_BITS: usize> BaseAluCoreChip<NUM_LIMBS, LIMB_BITS> {
187    pub fn new(
188        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<LIMB_BITS>,
189        offset: usize,
190    ) -> Self {
191        Self {
192            air: BaseAluCoreAir {
193                bus: bitwise_lookup_chip.bus(),
194                offset,
195            },
196            bitwise_lookup_chip,
197        }
198    }
199}
200
201impl<F, I, const NUM_LIMBS: usize, const LIMB_BITS: usize> VmCoreChip<F, I>
202    for BaseAluCoreChip<NUM_LIMBS, LIMB_BITS>
203where
204    F: PrimeField32,
205    I: VmAdapterInterface<F>,
206    I::Reads: Into<[[F; NUM_LIMBS]; 2]>,
207    I::Writes: From<[[F; NUM_LIMBS]; 1]>,
208{
209    type Record = BaseAluCoreRecord<F, NUM_LIMBS, LIMB_BITS>;
210    type Air = BaseAluCoreAir<NUM_LIMBS, LIMB_BITS>;
211
212    #[allow(clippy::type_complexity)]
213    fn execute_instruction(
214        &self,
215        instruction: &Instruction<F>,
216        _from_pc: u32,
217        reads: I::Reads,
218    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
219        let Instruction { opcode, .. } = instruction;
220        let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
221
222        let data: [[F; NUM_LIMBS]; 2] = reads.into();
223        let b = data[0].map(|x| x.as_canonical_u32());
224        let c = data[1].map(|y| y.as_canonical_u32());
225        let a = run_alu::<NUM_LIMBS, LIMB_BITS>(local_opcode, &b, &c);
226
227        let output = AdapterRuntimeContext {
228            to_pc: None,
229            writes: [a.map(F::from_canonical_u32)].into(),
230        };
231
232        if local_opcode == BaseAluOpcode::ADD || local_opcode == BaseAluOpcode::SUB {
233            for a_val in a {
234                self.bitwise_lookup_chip.request_xor(a_val, a_val);
235            }
236        } else {
237            for (b_val, c_val) in b.iter().zip(c.iter()) {
238                self.bitwise_lookup_chip.request_xor(*b_val, *c_val);
239            }
240        }
241
242        let record = Self::Record {
243            opcode: local_opcode,
244            a: a.map(F::from_canonical_u32),
245            b: data[0],
246            c: data[1],
247        };
248
249        Ok((output, record))
250    }
251
252    fn get_opcode_name(&self, opcode: usize) -> String {
253        format!("{:?}", BaseAluOpcode::from_usize(opcode - self.air.offset))
254    }
255
256    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
257        let row_slice: &mut BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut();
258        row_slice.a = record.a;
259        row_slice.b = record.b;
260        row_slice.c = record.c;
261        row_slice.opcode_add_flag = F::from_bool(record.opcode == BaseAluOpcode::ADD);
262        row_slice.opcode_sub_flag = F::from_bool(record.opcode == BaseAluOpcode::SUB);
263        row_slice.opcode_xor_flag = F::from_bool(record.opcode == BaseAluOpcode::XOR);
264        row_slice.opcode_or_flag = F::from_bool(record.opcode == BaseAluOpcode::OR);
265        row_slice.opcode_and_flag = F::from_bool(record.opcode == BaseAluOpcode::AND);
266    }
267
268    fn air(&self) -> &Self::Air {
269        &self.air
270    }
271}
272
273pub(super) fn run_alu<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
274    opcode: BaseAluOpcode,
275    x: &[u32; NUM_LIMBS],
276    y: &[u32; NUM_LIMBS],
277) -> [u32; NUM_LIMBS] {
278    match opcode {
279        BaseAluOpcode::ADD => run_add::<NUM_LIMBS, LIMB_BITS>(x, y),
280        BaseAluOpcode::SUB => run_subtract::<NUM_LIMBS, LIMB_BITS>(x, y),
281        BaseAluOpcode::XOR => run_xor::<NUM_LIMBS, LIMB_BITS>(x, y),
282        BaseAluOpcode::OR => run_or::<NUM_LIMBS, LIMB_BITS>(x, y),
283        BaseAluOpcode::AND => run_and::<NUM_LIMBS, LIMB_BITS>(x, y),
284    }
285}
286
/// Little-endian multi-limb addition `x + y`, wrapping modulo `2^(NUM_LIMBS * LIMB_BITS)`.
///
/// Each input limb is expected to be below `2^LIMB_BITS`. Each output limb is masked to
/// `LIMB_BITS` bits, and its carry feeds the next limb. The carry out of the top limb is
/// discarded.
fn run_add<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mask = (1u32 << LIMB_BITS) - 1;
    let mut sum = [0u32; NUM_LIMBS];
    // Only the previous limb's carry is ever needed, so a scalar suffices.
    let mut carry = 0u32;
    for ((out, &lhs), &rhs) in sum.iter_mut().zip(x).zip(y) {
        let total = lhs + rhs + carry;
        carry = total >> LIMB_BITS;
        *out = total & mask;
    }
    sum
}
300
/// Little-endian multi-limb subtraction `x - y`, wrapping modulo `2^(NUM_LIMBS * LIMB_BITS)`.
///
/// Each input limb is expected to be below `2^LIMB_BITS`. A borrow is taken from the next
/// limb whenever the current minuend limb is smaller than subtrahend plus incoming borrow. The
/// borrow out of the top limb is discarded.
fn run_subtract<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut diff = [0u32; NUM_LIMBS];
    // Only the previous limb's borrow is ever needed, so a scalar suffices.
    let mut borrow_in = 0u32;
    for ((out, &minuend), &subtrahend) in diff.iter_mut().zip(x).zip(y) {
        let rhs = subtrahend + borrow_in;
        let (value, borrow_out) = if minuend >= rhs {
            (minuend - rhs, 0)
        } else {
            (minuend + (1 << LIMB_BITS) - rhs, 1)
        };
        *out = value;
        borrow_in = borrow_out;
    }
    diff
}
319
/// Limb-wise bitwise XOR of `x` and `y`.
fn run_xor<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut out = *x;
    out.iter_mut().zip(y).for_each(|(lhs, &rhs)| *lhs ^= rhs);
    out
}
326
/// Limb-wise bitwise OR of `x` and `y`.
fn run_or<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut out = *x;
    out.iter_mut().zip(y).for_each(|(lhs, &rhs)| *lhs |= rhs);
    out
}
333
/// Limb-wise bitwise AND of `x` and `y`.
fn run_and<const NUM_LIMBS: usize, const LIMB_BITS: usize>(
    x: &[u32; NUM_LIMBS],
    y: &[u32; NUM_LIMBS],
) -> [u32; NUM_LIMBS] {
    let mut out = *x;
    out.iter_mut().zip(y).for_each(|(lhs, &rhs)| *lhs &= rhs);
    out
}