// openvm_circuit_primitives/bitwise_op_lookup/mod.rs
use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4 sync::{atomic::AtomicU32, Arc},
5};
6
7use openvm_circuit_primitives_derive::AlignedBorrow;
8use openvm_stark_backend::{
9 config::{StarkGenericConfig, Val},
10 interaction::InteractionBuilder,
11 p3_air::{Air, BaseAir, PairBuilder},
12 p3_field::{Field, PrimeCharacteristicRing},
13 p3_matrix::{dense::RowMajorMatrix, Matrix},
14 prover::{cpu::CpuBackend, types::AirProvingContext},
15 rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
16 Chip, ChipUsageGetter,
17};
18
19mod bus;
20pub use bus::*;
21
22#[cfg(feature = "cuda")]
23mod cuda;
24#[cfg(feature = "cuda")]
25pub use cuda::*;
26
27#[cfg(test)]
28mod tests;
29
/// Main trace columns: one multiplicity counter per supported operation.
///
/// Each row corresponds to one `(x, y)` pair of the preprocessed table; the
/// counters record how many times that pair was requested.
#[derive(Default, AlignedBorrow, Copy, Clone)]
#[repr(C)]
pub struct BitwiseOperationLookupCols<T> {
    /// Number of range-check requests for this row's `(x, y)` pair.
    pub mult_range: T,
    /// Number of xor requests for this row's `(x, y)` pair.
    pub mult_xor: T,
}
38
/// Preprocessed (constant) trace columns: the full `(x, y, x ^ y)` table over
/// all pairs of `NUM_BITS`-bit values.
#[derive(Default, AlignedBorrow, Copy, Clone)]
#[repr(C)]
pub struct BitwiseOperationLookupPreprocessedCols<T> {
    pub x: T,
    pub y: T,
    /// Precomputed `x ^ y` for this row.
    pub z_xor: T,
}
47
/// Width of the main trace (the two multiplicity columns).
pub const NUM_BITWISE_OP_LOOKUP_COLS: usize = size_of::<BitwiseOperationLookupCols<u8>>();
/// Width of the preprocessed trace (`x`, `y`, `x ^ y`).
pub const NUM_BITWISE_OP_LOOKUP_PREPROCESSED_COLS: usize =
    size_of::<BitwiseOperationLookupPreprocessedCols<u8>>();
51
/// AIR for the lookup table of bitwise operations on `NUM_BITS`-bit values.
/// The `(x, y, x ^ y)` table itself lives in the preprocessed trace; the main
/// trace carries only the request multiplicities.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct BitwiseOperationLookupAir<const NUM_BITS: usize> {
    pub bus: BitwiseOperationLookupBus,
}
56
// Marker impls: this AIR exposes no public values and has no partitioned trace.
impl<F: Field, const NUM_BITS: usize> BaseAirWithPublicValues<F>
    for BitwiseOperationLookupAir<NUM_BITS>
{
}
impl<F: Field, const NUM_BITS: usize> PartitionedBaseAir<F>
    for BitwiseOperationLookupAir<NUM_BITS>
{
}
65impl<F: Field, const NUM_BITS: usize> BaseAir<F> for BitwiseOperationLookupAir<NUM_BITS> {
66 fn width(&self) -> usize {
67 NUM_BITWISE_OP_LOOKUP_COLS
68 }
69
70 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
71 let rows: Vec<F> = (0..(1 << NUM_BITS))
72 .flat_map(|x: u32| {
73 (0..(1 << NUM_BITS))
74 .flat_map(move |y: u32| [F::from_u32(x), F::from_u32(y), F::from_u32(x ^ y)])
75 })
76 .collect();
77 Some(RowMajorMatrix::new(
78 rows,
79 NUM_BITWISE_OP_LOOKUP_PREPROCESSED_COLS,
80 ))
81 }
82}
83
84impl<AB: InteractionBuilder + PairBuilder, const NUM_BITS: usize> Air<AB>
85 for BitwiseOperationLookupAir<NUM_BITS>
86{
87 fn eval(&self, builder: &mut AB) {
88 let preprocessed = builder.preprocessed();
89 let prep_local = preprocessed
90 .row_slice(0)
91 .expect("window should have two elements");
92 let prep_local: &BitwiseOperationLookupPreprocessedCols<AB::Var> = (*prep_local).borrow();
93
94 let main = builder.main();
95 let local = main.row_slice(0).expect("window should have two elements");
96 let local: &BitwiseOperationLookupCols<AB::Var> = (*local).borrow();
97
98 self.bus
99 .receive(prep_local.x, prep_local.y, AB::F::ZERO, AB::F::ZERO)
100 .eval(builder, local.mult_range);
101 self.bus
102 .receive(prep_local.x, prep_local.y, prep_local.z_xor, AB::F::ONE)
103 .eval(builder, local.mult_xor);
104 }
105}
106
/// Lookup chip that accumulates, per `(x, y)` pair, how many times each
/// bitwise operation was requested. Counters are atomic so requestors can
/// record lookups concurrently through a shared reference.
pub struct BitwiseOperationLookupChip<const NUM_BITS: usize> {
    pub air: BitwiseOperationLookupAir<NUM_BITS>,
    /// Flattened `2^NUM_BITS x 2^NUM_BITS` grid of range-check multiplicities,
    /// indexed by `idx(x, y)`.
    pub count_range: Vec<AtomicU32>,
    /// Flattened grid of xor multiplicities, same indexing as `count_range`.
    pub count_xor: Vec<AtomicU32>,
}

/// Shared handle used when several requestors record into the same table.
pub type SharedBitwiseOperationLookupChip<const NUM_BITS: usize> =
    Arc<BitwiseOperationLookupChip<NUM_BITS>>;
119
120impl<const NUM_BITS: usize> BitwiseOperationLookupChip<NUM_BITS> {
121 pub fn new(bus: BitwiseOperationLookupBus) -> Self {
122 let num_rows = (1 << NUM_BITS) * (1 << NUM_BITS);
123 let count_range = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
124 let count_xor = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
125 Self {
126 air: BitwiseOperationLookupAir::new(bus),
127 count_range,
128 count_xor,
129 }
130 }
131
132 pub fn bus(&self) -> BitwiseOperationLookupBus {
133 self.air.bus
134 }
135
136 pub fn air_width(&self) -> usize {
137 NUM_BITWISE_OP_LOOKUP_COLS
138 }
139
140 pub fn request_range(&self, x: u32, y: u32) {
141 let upper_bound = 1 << NUM_BITS;
142 debug_assert!(x < upper_bound, "x out of range: {x} >= {upper_bound}");
143 debug_assert!(y < upper_bound, "y out of range: {y} >= {upper_bound}");
144 self.count_range[Self::idx(x, y)].fetch_add(1, std::sync::atomic::Ordering::Relaxed);
145 }
146
147 pub fn request_xor(&self, x: u32, y: u32) -> u32 {
148 let upper_bound = 1 << NUM_BITS;
149 debug_assert!(x < upper_bound, "x out of range: {x} >= {upper_bound}");
150 debug_assert!(y < upper_bound, "y out of range: {y} >= {upper_bound}");
151 self.count_xor[Self::idx(x, y)].fetch_add(1, std::sync::atomic::Ordering::Relaxed);
152 x ^ y
153 }
154
155 pub fn clear(&self) {
156 for i in 0..self.count_range.len() {
157 self.count_range[i].store(0, std::sync::atomic::Ordering::Relaxed);
158 self.count_xor[i].store(0, std::sync::atomic::Ordering::Relaxed);
159 }
160 }
161
162 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
164 let mut rows = F::zero_vec(self.count_range.len() * NUM_BITWISE_OP_LOOKUP_COLS);
165 for (n, row) in rows.chunks_mut(NUM_BITWISE_OP_LOOKUP_COLS).enumerate() {
166 let cols: &mut BitwiseOperationLookupCols<F> = row.borrow_mut();
167 cols.mult_range =
168 F::from_u32(self.count_range[n].swap(0, std::sync::atomic::Ordering::SeqCst));
169 cols.mult_xor =
170 F::from_u32(self.count_xor[n].swap(0, std::sync::atomic::Ordering::SeqCst));
171 }
172 RowMajorMatrix::new(rows, NUM_BITWISE_OP_LOOKUP_COLS)
173 }
174
175 fn idx(x: u32, y: u32) -> usize {
176 (x * (1 << NUM_BITS) + y) as usize
177 }
178}
179
180impl<R, SC: StarkGenericConfig, const NUM_BITS: usize> Chip<R, CpuBackend<SC>>
181 for BitwiseOperationLookupChip<NUM_BITS>
182{
183 fn generate_proving_ctx(&self, _: R) -> AirProvingContext<CpuBackend<SC>> {
185 let trace = self.generate_trace::<Val<SC>>();
186 AirProvingContext::simple_no_pis(Arc::new(trace))
187 }
188}
189
190impl<const NUM_BITS: usize> ChipUsageGetter for BitwiseOperationLookupChip<NUM_BITS> {
191 fn air_name(&self) -> String {
192 get_air_name(&self.air)
193 }
194 fn constant_trace_height(&self) -> Option<usize> {
195 Some(1 << (2 * NUM_BITS))
196 }
197 fn current_trace_height(&self) -> usize {
198 1 << (2 * NUM_BITS)
199 }
200 fn trace_width(&self) -> usize {
201 NUM_BITWISE_OP_LOOKUP_COLS
202 }
203}