// openvm_circuit_primitives/bitwise_op_lookup/mod.rs
use std::{
    borrow::{Borrow, BorrowMut},
    mem::size_of,
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
};
6
7use openvm_circuit_primitives_derive::AlignedBorrow;
8use openvm_stark_backend::{
9 config::{StarkGenericConfig, Val},
10 interaction::InteractionBuilder,
11 p3_air::{Air, BaseAir, PairBuilder},
12 p3_field::{Field, FieldAlgebra},
13 p3_matrix::{dense::RowMajorMatrix, Matrix},
14 prover::types::AirProofInput,
15 rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
16 AirRef, Chip, ChipUsageGetter,
17};
18
19mod bus;
20#[cfg(test)]
21mod tests;
22
23pub use bus::*;
24
25#[derive(Default, AlignedBorrow, Copy, Clone)]
26#[repr(C)]
27pub struct BitwiseOperationLookupCols<T> {
28 pub mult_range: T,
30 pub mult_xor: T,
32}
33
34#[derive(Default, AlignedBorrow, Copy, Clone)]
35#[repr(C)]
36pub struct BitwiseOperationLookupPreprocessedCols<T> {
37 pub x: T,
38 pub y: T,
39 pub z_xor: T,
41}
42
43pub const NUM_BITWISE_OP_LOOKUP_COLS: usize = size_of::<BitwiseOperationLookupCols<u8>>();
44pub const NUM_BITWISE_OP_LOOKUP_PREPROCESSED_COLS: usize =
45 size_of::<BitwiseOperationLookupPreprocessedCols<u8>>();
46
47#[derive(Clone, Copy, Debug, derive_new::new)]
48pub struct BitwiseOperationLookupAir<const NUM_BITS: usize> {
49 pub bus: BitwiseOperationLookupBus,
50}
51
52impl<F: Field, const NUM_BITS: usize> BaseAirWithPublicValues<F>
53 for BitwiseOperationLookupAir<NUM_BITS>
54{
55}
56impl<F: Field, const NUM_BITS: usize> PartitionedBaseAir<F>
57 for BitwiseOperationLookupAir<NUM_BITS>
58{
59}
60impl<F: Field, const NUM_BITS: usize> BaseAir<F> for BitwiseOperationLookupAir<NUM_BITS> {
61 fn width(&self) -> usize {
62 NUM_BITWISE_OP_LOOKUP_COLS
63 }
64
65 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
66 let rows: Vec<F> = (0..(1 << NUM_BITS))
67 .flat_map(|x: u32| {
68 (0..(1 << NUM_BITS)).flat_map(move |y: u32| {
69 [
70 F::from_canonical_u32(x),
71 F::from_canonical_u32(y),
72 F::from_canonical_u32(x ^ y),
73 ]
74 })
75 })
76 .collect();
77 Some(RowMajorMatrix::new(
78 rows,
79 NUM_BITWISE_OP_LOOKUP_PREPROCESSED_COLS,
80 ))
81 }
82}
83
84impl<AB: InteractionBuilder + PairBuilder, const NUM_BITS: usize> Air<AB>
85 for BitwiseOperationLookupAir<NUM_BITS>
86{
87 fn eval(&self, builder: &mut AB) {
88 let preprocessed = builder.preprocessed();
89 let prep_local = preprocessed.row_slice(0);
90 let prep_local: &BitwiseOperationLookupPreprocessedCols<AB::Var> = (*prep_local).borrow();
91
92 let main = builder.main();
93 let local = main.row_slice(0);
94 let local: &BitwiseOperationLookupCols<AB::Var> = (*local).borrow();
95
96 self.bus
97 .receive(prep_local.x, prep_local.y, AB::F::ZERO, AB::F::ZERO)
98 .eval(builder, local.mult_range);
99 self.bus
100 .receive(prep_local.x, prep_local.y, prep_local.z_xor, AB::F::ONE)
101 .eval(builder, local.mult_xor);
102 }
103}
104
105pub struct BitwiseOperationLookupChip<const NUM_BITS: usize> {
110 pub air: BitwiseOperationLookupAir<NUM_BITS>,
111 pub count_range: Vec<AtomicU32>,
112 pub count_xor: Vec<AtomicU32>,
113}
114
115#[derive(Clone)]
116pub struct SharedBitwiseOperationLookupChip<const NUM_BITS: usize>(
117 Arc<BitwiseOperationLookupChip<NUM_BITS>>,
118);
119
120impl<const NUM_BITS: usize> BitwiseOperationLookupChip<NUM_BITS> {
121 pub fn new(bus: BitwiseOperationLookupBus) -> Self {
122 let num_rows = (1 << NUM_BITS) * (1 << NUM_BITS);
123 let count_range = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
124 let count_xor = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
125 Self {
126 air: BitwiseOperationLookupAir::new(bus),
127 count_range,
128 count_xor,
129 }
130 }
131
132 pub fn bus(&self) -> BitwiseOperationLookupBus {
133 self.air.bus
134 }
135
136 pub fn air_width(&self) -> usize {
137 NUM_BITWISE_OP_LOOKUP_COLS
138 }
139
140 pub fn request_range(&self, x: u32, y: u32) {
141 let upper_bound = 1 << NUM_BITS;
142 debug_assert!(x < upper_bound, "x out of range: {} >= {}", x, upper_bound);
143 debug_assert!(y < upper_bound, "y out of range: {} >= {}", y, upper_bound);
144 self.count_range[Self::idx(x, y)].fetch_add(1, std::sync::atomic::Ordering::Relaxed);
145 }
146
147 pub fn request_xor(&self, x: u32, y: u32) -> u32 {
148 let upper_bound = 1 << NUM_BITS;
149 debug_assert!(x < upper_bound, "x out of range: {} >= {}", x, upper_bound);
150 debug_assert!(y < upper_bound, "y out of range: {} >= {}", y, upper_bound);
151 self.count_xor[Self::idx(x, y)].fetch_add(1, std::sync::atomic::Ordering::Relaxed);
152 x ^ y
153 }
154
155 pub fn clear(&self) {
156 for i in 0..self.count_range.len() {
157 self.count_range[i].store(0, std::sync::atomic::Ordering::Relaxed);
158 self.count_xor[i].store(0, std::sync::atomic::Ordering::Relaxed);
159 }
160 }
161
162 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
163 let mut rows = F::zero_vec(self.count_range.len() * NUM_BITWISE_OP_LOOKUP_COLS);
164 for (n, row) in rows.chunks_mut(NUM_BITWISE_OP_LOOKUP_COLS).enumerate() {
165 let cols: &mut BitwiseOperationLookupCols<F> = row.borrow_mut();
166 cols.mult_range = F::from_canonical_u32(
167 self.count_range[n].load(std::sync::atomic::Ordering::SeqCst),
168 );
169 cols.mult_xor =
170 F::from_canonical_u32(self.count_xor[n].load(std::sync::atomic::Ordering::SeqCst));
171 }
172 RowMajorMatrix::new(rows, NUM_BITWISE_OP_LOOKUP_COLS)
173 }
174
175 fn idx(x: u32, y: u32) -> usize {
176 (x * (1 << NUM_BITS) + y) as usize
177 }
178}
179
180impl<const NUM_BITS: usize> SharedBitwiseOperationLookupChip<NUM_BITS> {
181 pub fn new(bus: BitwiseOperationLookupBus) -> Self {
182 Self(Arc::new(BitwiseOperationLookupChip::new(bus)))
183 }
184 pub fn bus(&self) -> BitwiseOperationLookupBus {
185 self.0.bus()
186 }
187
188 pub fn air_width(&self) -> usize {
189 self.0.air_width()
190 }
191
192 pub fn request_range(&self, x: u32, y: u32) {
193 self.0.request_range(x, y);
194 }
195
196 pub fn request_xor(&self, x: u32, y: u32) -> u32 {
197 self.0.request_xor(x, y)
198 }
199
200 pub fn clear(&self) {
201 self.0.clear()
202 }
203
204 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
205 self.0.generate_trace()
206 }
207}
208
209impl<SC: StarkGenericConfig, const NUM_BITS: usize> Chip<SC>
210 for BitwiseOperationLookupChip<NUM_BITS>
211{
212 fn air(&self) -> AirRef<SC> {
213 Arc::new(self.air)
214 }
215
216 fn generate_air_proof_input(self) -> AirProofInput<SC> {
217 let trace = self.generate_trace::<Val<SC>>();
218 AirProofInput::simple_no_pis(trace)
219 }
220}
221
222impl<SC: StarkGenericConfig, const NUM_BITS: usize> Chip<SC>
223 for SharedBitwiseOperationLookupChip<NUM_BITS>
224{
225 fn air(&self) -> AirRef<SC> {
226 self.0.air()
227 }
228
229 fn generate_air_proof_input(self) -> AirProofInput<SC> {
230 self.0.generate_air_proof_input()
231 }
232}
233
234impl<const NUM_BITS: usize> ChipUsageGetter for BitwiseOperationLookupChip<NUM_BITS> {
235 fn air_name(&self) -> String {
236 get_air_name(&self.air)
237 }
238 fn constant_trace_height(&self) -> Option<usize> {
239 Some(1 << (2 * NUM_BITS))
240 }
241 fn current_trace_height(&self) -> usize {
242 1 << (2 * NUM_BITS)
243 }
244 fn trace_width(&self) -> usize {
245 NUM_BITWISE_OP_LOOKUP_COLS
246 }
247}
248
249impl<const NUM_BITS: usize> ChipUsageGetter for SharedBitwiseOperationLookupChip<NUM_BITS> {
250 fn air_name(&self) -> String {
251 self.0.air_name()
252 }
253
254 fn constant_trace_height(&self) -> Option<usize> {
255 self.0.constant_trace_height()
256 }
257
258 fn current_trace_height(&self) -> usize {
259 self.0.current_trace_height()
260 }
261
262 fn trace_width(&self) -> usize {
263 self.0.trace_width()
264 }
265}
266
267impl<const NUM_BITS: usize> AsRef<BitwiseOperationLookupChip<NUM_BITS>>
268 for SharedBitwiseOperationLookupChip<NUM_BITS>
269{
270 fn as_ref(&self) -> &BitwiseOperationLookupChip<NUM_BITS> {
271 &self.0
272 }
273}