openvm_circuit_primitives/bitwise_op_lookup/mod.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4    sync::{atomic::AtomicU32, Arc},
5};
6
7use openvm_circuit_primitives_derive::AlignedBorrow;
8use openvm_stark_backend::{
9    config::{StarkGenericConfig, Val},
10    interaction::InteractionBuilder,
11    p3_air::{Air, BaseAir, PairBuilder},
12    p3_field::{Field, FieldAlgebra},
13    p3_matrix::{dense::RowMajorMatrix, Matrix},
14    prover::types::AirProofInput,
15    rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
16    AirRef, Chip, ChipUsageGetter,
17};
18
19mod bus;
20#[cfg(test)]
21mod tests;
22
23pub use bus::*;
24
25#[derive(Default, AlignedBorrow, Copy, Clone)]
26#[repr(C)]
27pub struct BitwiseOperationLookupCols<T> {
28    /// Number of range check operations requested for each (x, y) pair
29    pub mult_range: T,
30    /// Number of XOR operations requested for each (x, y) pair
31    pub mult_xor: T,
32}
33
34#[derive(Default, AlignedBorrow, Copy, Clone)]
35#[repr(C)]
36pub struct BitwiseOperationLookupPreprocessedCols<T> {
37    pub x: T,
38    pub y: T,
39    /// XOR result of x and y (x ⊕ y)
40    pub z_xor: T,
41}
42
43pub const NUM_BITWISE_OP_LOOKUP_COLS: usize = size_of::<BitwiseOperationLookupCols<u8>>();
44pub const NUM_BITWISE_OP_LOOKUP_PREPROCESSED_COLS: usize =
45    size_of::<BitwiseOperationLookupPreprocessedCols<u8>>();
46
47#[derive(Clone, Copy, Debug, derive_new::new)]
48pub struct BitwiseOperationLookupAir<const NUM_BITS: usize> {
49    pub bus: BitwiseOperationLookupBus,
50}
51
52impl<F: Field, const NUM_BITS: usize> BaseAirWithPublicValues<F>
53    for BitwiseOperationLookupAir<NUM_BITS>
54{
55}
56impl<F: Field, const NUM_BITS: usize> PartitionedBaseAir<F>
57    for BitwiseOperationLookupAir<NUM_BITS>
58{
59}
60impl<F: Field, const NUM_BITS: usize> BaseAir<F> for BitwiseOperationLookupAir<NUM_BITS> {
61    fn width(&self) -> usize {
62        NUM_BITWISE_OP_LOOKUP_COLS
63    }
64
65    fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
66        let rows: Vec<F> = (0..(1 << NUM_BITS))
67            .flat_map(|x: u32| {
68                (0..(1 << NUM_BITS)).flat_map(move |y: u32| {
69                    [
70                        F::from_canonical_u32(x),
71                        F::from_canonical_u32(y),
72                        F::from_canonical_u32(x ^ y),
73                    ]
74                })
75            })
76            .collect();
77        Some(RowMajorMatrix::new(
78            rows,
79            NUM_BITWISE_OP_LOOKUP_PREPROCESSED_COLS,
80        ))
81    }
82}
83
84impl<AB: InteractionBuilder + PairBuilder, const NUM_BITS: usize> Air<AB>
85    for BitwiseOperationLookupAir<NUM_BITS>
86{
87    fn eval(&self, builder: &mut AB) {
88        let preprocessed = builder.preprocessed();
89        let prep_local = preprocessed.row_slice(0);
90        let prep_local: &BitwiseOperationLookupPreprocessedCols<AB::Var> = (*prep_local).borrow();
91
92        let main = builder.main();
93        let local = main.row_slice(0);
94        let local: &BitwiseOperationLookupCols<AB::Var> = (*local).borrow();
95
96        self.bus
97            .receive(prep_local.x, prep_local.y, AB::F::ZERO, AB::F::ZERO)
98            .eval(builder, local.mult_range);
99        self.bus
100            .receive(prep_local.x, prep_local.y, prep_local.z_xor, AB::F::ONE)
101            .eval(builder, local.mult_xor);
102    }
103}
104
// Lookup chip for bitwise operations on NUM_BITS-bit integers. The preprocessed columns hold
// every (x, y) pair together with x ^ y, so the table serves both range checks and XOR.
// Each interaction receives (x, y, z, op) on the bus: z = 0 and op = 0 for a range check,
// z = x ^ y and op = 1 for XOR.
108
109pub struct BitwiseOperationLookupChip<const NUM_BITS: usize> {
110    pub air: BitwiseOperationLookupAir<NUM_BITS>,
111    pub count_range: Vec<AtomicU32>,
112    pub count_xor: Vec<AtomicU32>,
113}
114
115#[derive(Clone)]
116pub struct SharedBitwiseOperationLookupChip<const NUM_BITS: usize>(
117    Arc<BitwiseOperationLookupChip<NUM_BITS>>,
118);
119
120impl<const NUM_BITS: usize> BitwiseOperationLookupChip<NUM_BITS> {
121    pub fn new(bus: BitwiseOperationLookupBus) -> Self {
122        let num_rows = (1 << NUM_BITS) * (1 << NUM_BITS);
123        let count_range = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
124        let count_xor = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
125        Self {
126            air: BitwiseOperationLookupAir::new(bus),
127            count_range,
128            count_xor,
129        }
130    }
131
132    pub fn bus(&self) -> BitwiseOperationLookupBus {
133        self.air.bus
134    }
135
136    pub fn air_width(&self) -> usize {
137        NUM_BITWISE_OP_LOOKUP_COLS
138    }
139
140    pub fn request_range(&self, x: u32, y: u32) {
141        let upper_bound = 1 << NUM_BITS;
142        debug_assert!(x < upper_bound, "x out of range: {} >= {}", x, upper_bound);
143        debug_assert!(y < upper_bound, "y out of range: {} >= {}", y, upper_bound);
144        self.count_range[Self::idx(x, y)].fetch_add(1, std::sync::atomic::Ordering::Relaxed);
145    }
146
147    pub fn request_xor(&self, x: u32, y: u32) -> u32 {
148        let upper_bound = 1 << NUM_BITS;
149        debug_assert!(x < upper_bound, "x out of range: {} >= {}", x, upper_bound);
150        debug_assert!(y < upper_bound, "y out of range: {} >= {}", y, upper_bound);
151        self.count_xor[Self::idx(x, y)].fetch_add(1, std::sync::atomic::Ordering::Relaxed);
152        x ^ y
153    }
154
155    pub fn clear(&self) {
156        for i in 0..self.count_range.len() {
157            self.count_range[i].store(0, std::sync::atomic::Ordering::Relaxed);
158            self.count_xor[i].store(0, std::sync::atomic::Ordering::Relaxed);
159        }
160    }
161
162    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
163        let mut rows = F::zero_vec(self.count_range.len() * NUM_BITWISE_OP_LOOKUP_COLS);
164        for (n, row) in rows.chunks_mut(NUM_BITWISE_OP_LOOKUP_COLS).enumerate() {
165            let cols: &mut BitwiseOperationLookupCols<F> = row.borrow_mut();
166            cols.mult_range = F::from_canonical_u32(
167                self.count_range[n].load(std::sync::atomic::Ordering::SeqCst),
168            );
169            cols.mult_xor =
170                F::from_canonical_u32(self.count_xor[n].load(std::sync::atomic::Ordering::SeqCst));
171        }
172        RowMajorMatrix::new(rows, NUM_BITWISE_OP_LOOKUP_COLS)
173    }
174
175    fn idx(x: u32, y: u32) -> usize {
176        (x * (1 << NUM_BITS) + y) as usize
177    }
178}
179
180impl<const NUM_BITS: usize> SharedBitwiseOperationLookupChip<NUM_BITS> {
181    pub fn new(bus: BitwiseOperationLookupBus) -> Self {
182        Self(Arc::new(BitwiseOperationLookupChip::new(bus)))
183    }
184    pub fn bus(&self) -> BitwiseOperationLookupBus {
185        self.0.bus()
186    }
187
188    pub fn air_width(&self) -> usize {
189        self.0.air_width()
190    }
191
192    pub fn request_range(&self, x: u32, y: u32) {
193        self.0.request_range(x, y);
194    }
195
196    pub fn request_xor(&self, x: u32, y: u32) -> u32 {
197        self.0.request_xor(x, y)
198    }
199
200    pub fn clear(&self) {
201        self.0.clear()
202    }
203
204    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
205        self.0.generate_trace()
206    }
207}
208
209impl<SC: StarkGenericConfig, const NUM_BITS: usize> Chip<SC>
210    for BitwiseOperationLookupChip<NUM_BITS>
211{
212    fn air(&self) -> AirRef<SC> {
213        Arc::new(self.air)
214    }
215
216    fn generate_air_proof_input(self) -> AirProofInput<SC> {
217        let trace = self.generate_trace::<Val<SC>>();
218        AirProofInput::simple_no_pis(trace)
219    }
220}
221
222impl<SC: StarkGenericConfig, const NUM_BITS: usize> Chip<SC>
223    for SharedBitwiseOperationLookupChip<NUM_BITS>
224{
225    fn air(&self) -> AirRef<SC> {
226        self.0.air()
227    }
228
229    fn generate_air_proof_input(self) -> AirProofInput<SC> {
230        self.0.generate_air_proof_input()
231    }
232}
233
234impl<const NUM_BITS: usize> ChipUsageGetter for BitwiseOperationLookupChip<NUM_BITS> {
235    fn air_name(&self) -> String {
236        get_air_name(&self.air)
237    }
238    fn constant_trace_height(&self) -> Option<usize> {
239        Some(1 << (2 * NUM_BITS))
240    }
241    fn current_trace_height(&self) -> usize {
242        1 << (2 * NUM_BITS)
243    }
244    fn trace_width(&self) -> usize {
245        NUM_BITWISE_OP_LOOKUP_COLS
246    }
247}
248
249impl<const NUM_BITS: usize> ChipUsageGetter for SharedBitwiseOperationLookupChip<NUM_BITS> {
250    fn air_name(&self) -> String {
251        self.0.air_name()
252    }
253
254    fn constant_trace_height(&self) -> Option<usize> {
255        self.0.constant_trace_height()
256    }
257
258    fn current_trace_height(&self) -> usize {
259        self.0.current_trace_height()
260    }
261
262    fn trace_width(&self) -> usize {
263        self.0.trace_width()
264    }
265}
266
267impl<const NUM_BITS: usize> AsRef<BitwiseOperationLookupChip<NUM_BITS>>
268    for SharedBitwiseOperationLookupChip<NUM_BITS>
269{
270    fn as_ref(&self) -> &BitwiseOperationLookupChip<NUM_BITS> {
271        &self.0
272    }
273}