// openvm_circuit_primitives/xor/lookup/mod.rs
1use std::{
7 borrow::Borrow,
8 mem::size_of,
9 sync::{
10 atomic::{self, AtomicU32},
11 Arc,
12 },
13};
14
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_stark_backend::{
17 config::{StarkGenericConfig, Val},
18 interaction::{BusIndex, InteractionBuilder, LookupBus},
19 p3_air::{Air, BaseAir, PairBuilder},
20 p3_field::Field,
21 p3_matrix::{dense::RowMajorMatrix, Matrix},
22 prover::types::AirProofInput,
23 rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
24 AirRef, Chip, ChipUsageGetter,
25};
26
27use super::bus::XorBus;
28
29#[cfg(test)]
30mod tests;
31
32#[repr(C)]
34#[derive(Copy, Clone, Debug, AlignedBorrow)]
35pub struct XorLookupCols<T> {
36 pub mult: T,
38}
39
40#[repr(C)]
42#[derive(Copy, Clone, Debug, AlignedBorrow)]
43pub struct XorLookupPreprocessedCols<T> {
44 pub x: T,
45 pub y: T,
46 pub z: T,
48}
49
50pub const NUM_XOR_LOOKUP_COLS: usize = size_of::<XorLookupCols<u8>>();
51pub const NUM_XOR_LOOKUP_PREPROCESSED_COLS: usize = size_of::<XorLookupPreprocessedCols<u8>>();
52
53#[derive(Clone, Copy, Debug, derive_new::new)]
55pub struct XorLookupAir<const M: usize> {
56 pub bus: XorBus,
57}
58
59impl<F: Field, const M: usize> BaseAirWithPublicValues<F> for XorLookupAir<M> {}
60impl<F: Field, const M: usize> PartitionedBaseAir<F> for XorLookupAir<M> {}
61impl<F: Field, const M: usize> BaseAir<F> for XorLookupAir<M> {
62 fn width(&self) -> usize {
63 NUM_XOR_LOOKUP_COLS
64 }
65
66 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
68 let rows: Vec<_> = (0..(1 << M) * (1 << M))
69 .flat_map(|i| {
70 let x = i / (1 << M);
71 let y = i % (1 << M);
72 let z = x ^ y;
73 [x, y, z].map(F::from_canonical_u32)
74 })
75 .collect();
76
77 Some(RowMajorMatrix::new(rows, NUM_XOR_LOOKUP_PREPROCESSED_COLS))
78 }
79}
80
81impl<AB, const M: usize> Air<AB> for XorLookupAir<M>
82where
83 AB: InteractionBuilder + PairBuilder,
84{
85 fn eval(&self, builder: &mut AB) {
86 let main = builder.main();
87 let preprocessed = builder.preprocessed();
88
89 let prep_local = preprocessed.row_slice(0);
90 let prep_local: &XorLookupPreprocessedCols<AB::Var> = (*prep_local).borrow();
91 let local = main.row_slice(0);
92 let local: &XorLookupCols<AB::Var> = (*local).borrow();
93
94 self.bus
95 .receive(prep_local.x, prep_local.y, prep_local.z)
96 .eval(builder, local.mult);
97 }
98}
99
100#[derive(Debug)]
104pub struct XorLookupChip<const M: usize> {
105 pub air: XorLookupAir<M>,
106 pub count: Vec<Vec<AtomicU32>>,
108}
109
110impl<const M: usize> XorLookupChip<M> {
111 pub fn new(bus: BusIndex) -> Self {
112 let mut count = vec![];
113 for _ in 0..(1 << M) {
114 let mut row = vec![];
115 for _ in 0..(1 << M) {
116 row.push(AtomicU32::new(0));
117 }
118 count.push(row);
119 }
120 Self {
121 air: XorLookupAir::new(XorBus(LookupBus::new(bus))),
122 count,
123 }
124 }
125
126 pub fn bus(&self) -> XorBus {
128 self.air.bus
129 }
130
131 fn calc_xor(&self, x: u32, y: u32) -> u32 {
132 x ^ y
133 }
134
135 pub fn request(&self, x: u32, y: u32) -> u32 {
138 let val_atomic = &self.count[x as usize][y as usize];
139 val_atomic.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
140
141 self.calc_xor(x, y)
142 }
143
144 pub fn clear(&self) {
146 for i in 0..(1 << M) {
147 for j in 0..(1 << M) {
148 self.count[i][j].store(0, std::sync::atomic::Ordering::Relaxed);
149 }
150 }
151 }
152
153 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
155 debug_assert_eq!(self.count.len(), 1 << M);
156 let multiplicities: Vec<_> = self
157 .count
158 .iter()
159 .flat_map(|count_x| {
160 debug_assert_eq!(count_x.len(), 1 << M);
161 count_x
162 .iter()
163 .map(|count_xy| F::from_canonical_u32(count_xy.load(atomic::Ordering::SeqCst)))
164 })
165 .collect();
166
167 RowMajorMatrix::new_col(multiplicities)
168 }
169}
170
171impl<SC: StarkGenericConfig, const M: usize> Chip<SC> for XorLookupChip<M> {
172 fn air(&self) -> AirRef<SC> {
173 Arc::new(self.air)
174 }
175
176 fn generate_air_proof_input(self) -> AirProofInput<SC> {
177 let trace = self.generate_trace::<Val<SC>>();
178 AirProofInput::simple_no_pis(trace)
179 }
180}
181
182impl<const M: usize> ChipUsageGetter for XorLookupChip<M> {
183 fn air_name(&self) -> String {
184 get_air_name(&self.air)
185 }
186
187 fn current_trace_height(&self) -> usize {
188 1 << (2 * M)
189 }
190
191 fn trace_width(&self) -> usize {
192 NUM_XOR_LOOKUP_COLS
193 }
194}