// openvm_circuit_primitives/xor/lookup/
mod.rs1use std::{
7 borrow::Borrow,
8 mem::size_of,
9 sync::{
10 atomic::{self, AtomicU32},
11 Arc,
12 },
13};
14
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_stark_backend::{
17 config::{StarkGenericConfig, Val},
18 interaction::{BusIndex, InteractionBuilder, LookupBus},
19 p3_air::{Air, BaseAir, PairBuilder},
20 p3_field::Field,
21 p3_matrix::{dense::RowMajorMatrix, Matrix},
22 prover::{cpu::CpuBackend, types::AirProvingContext},
23 rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
24 Chip, ChipUsageGetter,
25};
26
27use super::bus::XorBus;
28
29#[cfg(test)]
30mod tests;
31
32#[repr(C)]
34#[derive(Copy, Clone, Debug, AlignedBorrow)]
35pub struct XorLookupCols<T> {
36 pub mult: T,
38}
39
40#[repr(C)]
42#[derive(Copy, Clone, Debug, AlignedBorrow)]
43pub struct XorLookupPreprocessedCols<T> {
44 pub x: T,
45 pub y: T,
46 pub z: T,
48}
49
50pub const NUM_XOR_LOOKUP_COLS: usize = size_of::<XorLookupCols<u8>>();
51pub const NUM_XOR_LOOKUP_PREPROCESSED_COLS: usize = size_of::<XorLookupPreprocessedCols<u8>>();
52
53#[derive(Clone, Copy, Debug, derive_new::new)]
56pub struct XorLookupAir<const M: usize> {
57 pub bus: XorBus,
58}
59
60impl<F: Field, const M: usize> BaseAirWithPublicValues<F> for XorLookupAir<M> {}
61impl<F: Field, const M: usize> PartitionedBaseAir<F> for XorLookupAir<M> {}
62impl<F: Field, const M: usize> BaseAir<F> for XorLookupAir<M> {
63 fn width(&self) -> usize {
64 NUM_XOR_LOOKUP_COLS
65 }
66
67 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
69 let rows: Vec<_> = (0..(1 << M) * (1 << M))
70 .flat_map(|i| {
71 let x = i / (1 << M);
72 let y = i % (1 << M);
73 let z = x ^ y;
74 [x, y, z].map(F::from_canonical_u32)
75 })
76 .collect();
77
78 Some(RowMajorMatrix::new(rows, NUM_XOR_LOOKUP_PREPROCESSED_COLS))
79 }
80}
81
82impl<AB, const M: usize> Air<AB> for XorLookupAir<M>
83where
84 AB: InteractionBuilder + PairBuilder,
85{
86 fn eval(&self, builder: &mut AB) {
87 let main = builder.main();
88 let preprocessed = builder.preprocessed();
89
90 let prep_local = preprocessed.row_slice(0);
91 let prep_local: &XorLookupPreprocessedCols<AB::Var> = (*prep_local).borrow();
92 let local = main.row_slice(0);
93 let local: &XorLookupCols<AB::Var> = (*local).borrow();
94
95 self.bus
96 .receive(prep_local.x, prep_local.y, prep_local.z)
97 .eval(builder, local.mult);
98 }
99}
100
101#[derive(Debug)]
106pub struct XorLookupChip<const M: usize> {
107 pub air: XorLookupAir<M>,
108 pub count: Vec<Vec<AtomicU32>>,
110}
111
112impl<const M: usize> XorLookupChip<M> {
113 pub fn new(bus: BusIndex) -> Self {
114 let mut count = vec![];
115 for _ in 0..(1 << M) {
116 let mut row = vec![];
117 for _ in 0..(1 << M) {
118 row.push(AtomicU32::new(0));
119 }
120 count.push(row);
121 }
122 Self {
123 air: XorLookupAir::new(XorBus(LookupBus::new(bus))),
124 count,
125 }
126 }
127
128 pub fn bus(&self) -> XorBus {
130 self.air.bus
131 }
132
133 fn calc_xor(&self, x: u32, y: u32) -> u32 {
134 x ^ y
135 }
136
137 pub fn request(&self, x: u32, y: u32) -> u32 {
140 let val_atomic = &self.count[x as usize][y as usize];
141 val_atomic.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
142
143 self.calc_xor(x, y)
144 }
145
146 pub fn clear(&self) {
148 for i in 0..(1 << M) {
149 for j in 0..(1 << M) {
150 self.count[i][j].store(0, std::sync::atomic::Ordering::Relaxed);
151 }
152 }
153 }
154
155 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
157 debug_assert_eq!(self.count.len(), 1 << M);
158 let multiplicities: Vec<_> = self
159 .count
160 .iter()
161 .flat_map(|count_x| {
162 debug_assert_eq!(count_x.len(), 1 << M);
163 count_x
164 .iter()
165 .map(|count_xy| F::from_canonical_u32(count_xy.load(atomic::Ordering::SeqCst)))
166 })
167 .collect();
168
169 RowMajorMatrix::new_col(multiplicities)
170 }
171}
172
173impl<R, SC: StarkGenericConfig, const M: usize> Chip<R, CpuBackend<SC>> for XorLookupChip<M> {
174 fn generate_proving_ctx(&self, _: R) -> AirProvingContext<CpuBackend<SC>> {
175 let trace = self.generate_trace::<Val<SC>>();
176 AirProvingContext::simple_no_pis(Arc::new(trace))
177 }
178}
179
180impl<const M: usize> ChipUsageGetter for XorLookupChip<M> {
181 fn air_name(&self) -> String {
182 get_air_name(&self.air)
183 }
184
185 fn current_trace_height(&self) -> usize {
186 1 << (2 * M)
187 }
188
189 fn trace_width(&self) -> usize {
190 NUM_XOR_LOOKUP_COLS
191 }
192}