openvm_circuit_primitives/bitwise_op_lookup/mod.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4    sync::{atomic::AtomicU32, Arc},
5};
6
7use openvm_circuit_primitives_derive::AlignedBorrow;
8use openvm_stark_backend::{
9    config::{StarkGenericConfig, Val},
10    interaction::InteractionBuilder,
11    p3_air::{Air, BaseAir, PairBuilder},
12    p3_field::{Field, FieldAlgebra},
13    p3_matrix::{dense::RowMajorMatrix, Matrix},
14    prover::{cpu::CpuBackend, types::AirProvingContext},
15    rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
16    Chip, ChipUsageGetter,
17};
18
19mod bus;
20pub use bus::*;
21
22#[cfg(feature = "cuda")]
23mod cuda;
24#[cfg(feature = "cuda")]
25pub use cuda::*;
26
27#[cfg(test)]
28mod tests;
29
/// Main-trace columns of the bitwise lookup table.
///
/// There is one row per `(x, y)` pair; the pair itself lives in the preprocessed columns
/// ([`BitwiseOperationLookupPreprocessedCols`]). These columns only carry the multiplicities
/// with which each operation on that pair is received on the bus.
#[derive(Default, AlignedBorrow, Copy, Clone)]
#[repr(C)]
pub struct BitwiseOperationLookupCols<T> {
    /// Number of range check operations requested for each (x, y) pair
    pub mult_range: T,
    /// Number of XOR operations requested for each (x, y) pair
    pub mult_xor: T,
}
38
/// Preprocessed (fixed) columns of the bitwise lookup table.
///
/// Row `x * 2^NUM_BITS + y` holds `(x, y, x ^ y)` for every `x, y < 2^NUM_BITS`
/// (built by `BaseAir::preprocessed_trace` for `BitwiseOperationLookupAir`).
#[derive(Default, AlignedBorrow, Copy, Clone)]
#[repr(C)]
pub struct BitwiseOperationLookupPreprocessedCols<T> {
    /// Left operand.
    pub x: T,
    /// Right operand.
    pub y: T,
    /// XOR result of x and y (x ⊕ y)
    pub z_xor: T,
}
47
/// Number of main-trace columns, derived from the `#[repr(C)]` layout of
/// [`BitwiseOperationLookupCols`] instantiated with a one-byte cell type.
pub const NUM_BITWISE_OP_LOOKUP_COLS: usize = size_of::<BitwiseOperationLookupCols<u8>>();
/// Number of preprocessed columns, derived the same way from
/// [`BitwiseOperationLookupPreprocessedCols`].
pub const NUM_BITWISE_OP_LOOKUP_PREPROCESSED_COLS: usize =
    size_of::<BitwiseOperationLookupPreprocessedCols<u8>>();
51
/// AIR for the bitwise lookup table over all pairs of `NUM_BITS`-bit values.
///
/// For every row, its `eval` receives one range-check and one XOR interaction on `bus`.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct BitwiseOperationLookupAir<const NUM_BITS: usize> {
    /// Bus on which range-check and XOR requests are received.
    pub bus: BitwiseOperationLookupBus,
}
56
57impl<F: Field, const NUM_BITS: usize> BaseAirWithPublicValues<F>
58    for BitwiseOperationLookupAir<NUM_BITS>
59{
60}
61impl<F: Field, const NUM_BITS: usize> PartitionedBaseAir<F>
62    for BitwiseOperationLookupAir<NUM_BITS>
63{
64}
65impl<F: Field, const NUM_BITS: usize> BaseAir<F> for BitwiseOperationLookupAir<NUM_BITS> {
66    fn width(&self) -> usize {
67        NUM_BITWISE_OP_LOOKUP_COLS
68    }
69
70    fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
71        let rows: Vec<F> = (0..(1 << NUM_BITS))
72            .flat_map(|x: u32| {
73                (0..(1 << NUM_BITS)).flat_map(move |y: u32| {
74                    [
75                        F::from_canonical_u32(x),
76                        F::from_canonical_u32(y),
77                        F::from_canonical_u32(x ^ y),
78                    ]
79                })
80            })
81            .collect();
82        Some(RowMajorMatrix::new(
83            rows,
84            NUM_BITWISE_OP_LOOKUP_PREPROCESSED_COLS,
85        ))
86    }
87}
88
89impl<AB: InteractionBuilder + PairBuilder, const NUM_BITS: usize> Air<AB>
90    for BitwiseOperationLookupAir<NUM_BITS>
91{
92    fn eval(&self, builder: &mut AB) {
93        let preprocessed = builder.preprocessed();
94        let prep_local = preprocessed.row_slice(0);
95        let prep_local: &BitwiseOperationLookupPreprocessedCols<AB::Var> = (*prep_local).borrow();
96
97        let main = builder.main();
98        let local = main.row_slice(0);
99        let local: &BitwiseOperationLookupCols<AB::Var> = (*local).borrow();
100
101        self.bus
102            .receive(prep_local.x, prep_local.y, AB::F::ZERO, AB::F::ZERO)
103            .eval(builder, local.mult_range);
104        self.bus
105            .receive(prep_local.x, prep_local.y, prep_local.z_xor, AB::F::ONE)
106            .eval(builder, local.mult_xor);
107    }
108}
109
/// Lookup chip for bitwise operations on `NUM_BITS`-bit integers.
///
/// The table has one row per `(x, y)` pair with `x, y < 2^NUM_BITS`; its preprocessed columns
/// hold `x`, `y` and `x ^ y`. Interactions have the form `[x, y, z, op]`: a range check is
/// `[x, y, 0, 0]` and an XOR is `[x, y, x ^ y, 1]`.
pub struct BitwiseOperationLookupChip<const NUM_BITS: usize> {
    /// The AIR (and bus) this chip generates traces for.
    pub air: BitwiseOperationLookupAir<NUM_BITS>,
    /// Pending range-check request counts, indexed by `x * 2^NUM_BITS + y`.
    pub count_range: Vec<AtomicU32>,
    /// Pending XOR request counts, indexed by `x * 2^NUM_BITS + y`.
    pub count_xor: Vec<AtomicU32>,
}
119
/// Reference-counted handle to a [`BitwiseOperationLookupChip`].
///
/// The chip's request methods take `&self` and update atomic counters, so they can be called
/// through this shared handle.
pub type SharedBitwiseOperationLookupChip<const NUM_BITS: usize> =
    Arc<BitwiseOperationLookupChip<NUM_BITS>>;
122
123impl<const NUM_BITS: usize> BitwiseOperationLookupChip<NUM_BITS> {
124    pub fn new(bus: BitwiseOperationLookupBus) -> Self {
125        let num_rows = (1 << NUM_BITS) * (1 << NUM_BITS);
126        let count_range = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
127        let count_xor = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
128        Self {
129            air: BitwiseOperationLookupAir::new(bus),
130            count_range,
131            count_xor,
132        }
133    }
134
135    pub fn bus(&self) -> BitwiseOperationLookupBus {
136        self.air.bus
137    }
138
139    pub fn air_width(&self) -> usize {
140        NUM_BITWISE_OP_LOOKUP_COLS
141    }
142
143    pub fn request_range(&self, x: u32, y: u32) {
144        let upper_bound = 1 << NUM_BITS;
145        debug_assert!(x < upper_bound, "x out of range: {} >= {}", x, upper_bound);
146        debug_assert!(y < upper_bound, "y out of range: {} >= {}", y, upper_bound);
147        self.count_range[Self::idx(x, y)].fetch_add(1, std::sync::atomic::Ordering::Relaxed);
148    }
149
150    pub fn request_xor(&self, x: u32, y: u32) -> u32 {
151        let upper_bound = 1 << NUM_BITS;
152        debug_assert!(x < upper_bound, "x out of range: {} >= {}", x, upper_bound);
153        debug_assert!(y < upper_bound, "y out of range: {} >= {}", y, upper_bound);
154        self.count_xor[Self::idx(x, y)].fetch_add(1, std::sync::atomic::Ordering::Relaxed);
155        x ^ y
156    }
157
158    pub fn clear(&self) {
159        for i in 0..self.count_range.len() {
160            self.count_range[i].store(0, std::sync::atomic::Ordering::Relaxed);
161            self.count_xor[i].store(0, std::sync::atomic::Ordering::Relaxed);
162        }
163    }
164
165    /// Generates trace and resets all internal counters to 0.
166    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
167        let mut rows = F::zero_vec(self.count_range.len() * NUM_BITWISE_OP_LOOKUP_COLS);
168        for (n, row) in rows.chunks_mut(NUM_BITWISE_OP_LOOKUP_COLS).enumerate() {
169            let cols: &mut BitwiseOperationLookupCols<F> = row.borrow_mut();
170            cols.mult_range = F::from_canonical_u32(
171                self.count_range[n].swap(0, std::sync::atomic::Ordering::SeqCst),
172            );
173            cols.mult_xor = F::from_canonical_u32(
174                self.count_xor[n].swap(0, std::sync::atomic::Ordering::SeqCst),
175            );
176        }
177        RowMajorMatrix::new(rows, NUM_BITWISE_OP_LOOKUP_COLS)
178    }
179
180    fn idx(x: u32, y: u32) -> usize {
181        (x * (1 << NUM_BITS) + y) as usize
182    }
183}
184
185impl<R, SC: StarkGenericConfig, const NUM_BITS: usize> Chip<R, CpuBackend<SC>>
186    for BitwiseOperationLookupChip<NUM_BITS>
187{
188    /// Generates trace and resets all internal counters to 0.
189    fn generate_proving_ctx(&self, _: R) -> AirProvingContext<CpuBackend<SC>> {
190        let trace = self.generate_trace::<Val<SC>>();
191        AirProvingContext::simple_no_pis(Arc::new(trace))
192    }
193}
194
195impl<const NUM_BITS: usize> ChipUsageGetter for BitwiseOperationLookupChip<NUM_BITS> {
196    fn air_name(&self) -> String {
197        get_air_name(&self.air)
198    }
199    fn constant_trace_height(&self) -> Option<usize> {
200        Some(1 << (2 * NUM_BITS))
201    }
202    fn current_trace_height(&self) -> usize {
203        1 << (2 * NUM_BITS)
204    }
205    fn trace_width(&self) -> usize {
206        NUM_BITWISE_OP_LOOKUP_COLS
207    }
208}