openvm_circuit_primitives/xor/lookup/
mod.rs

//! A chip that uses a preprocessed trace to provide a lookup table for XOR operations
//! between two numbers `x` and `y` of at most `M` bits.
3//! It generates a preprocessed table with a row for each possible triple `(x, y, x^y)`
4//! and keeps count of the number of times each triple is requested.
5
6use std::{
7    borrow::Borrow,
8    mem::size_of,
9    sync::{
10        atomic::{self, AtomicU32},
11        Arc,
12    },
13};
14
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_stark_backend::{
17    config::{StarkGenericConfig, Val},
18    interaction::{BusIndex, InteractionBuilder, LookupBus},
19    p3_air::{Air, BaseAir, PairBuilder},
20    p3_field::Field,
21    p3_matrix::{dense::RowMajorMatrix, Matrix},
22    prover::{cpu::CpuBackend, types::AirProvingContext},
23    rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
24    Chip, ChipUsageGetter,
25};
26
27use super::bus::XorBus;
28
29#[cfg(test)]
30mod tests;
31
32/// Columns for the main trace of the XOR lookup
33#[repr(C)]
34#[derive(Copy, Clone, Debug, AlignedBorrow)]
35pub struct XorLookupCols<T> {
36    /// Multiplicity counter tracking the number of XOR operations requested for each triple
37    pub mult: T,
38}
39
40/// Columns for the preprocessed table of the XOR lookup
41#[repr(C)]
42#[derive(Copy, Clone, Debug, AlignedBorrow)]
43pub struct XorLookupPreprocessedCols<T> {
44    pub x: T,
45    pub y: T,
46    /// XOR result (x ⊕ y)
47    pub z: T,
48}
49
50pub const NUM_XOR_LOOKUP_COLS: usize = size_of::<XorLookupCols<u8>>();
51pub const NUM_XOR_LOOKUP_PREPROCESSED_COLS: usize = size_of::<XorLookupPreprocessedCols<u8>>();
52
53/// Xor via preprocessed lookup table. Can only be used if inputs have less than approximately
54/// 10-bits.
55#[derive(Clone, Copy, Debug, derive_new::new)]
56pub struct XorLookupAir<const M: usize> {
57    pub bus: XorBus,
58}
59
60impl<F: Field, const M: usize> BaseAirWithPublicValues<F> for XorLookupAir<M> {}
61impl<F: Field, const M: usize> PartitionedBaseAir<F> for XorLookupAir<M> {}
62impl<F: Field, const M: usize> BaseAir<F> for XorLookupAir<M> {
63    fn width(&self) -> usize {
64        NUM_XOR_LOOKUP_COLS
65    }
66
67    /// Generates a preprocessed table with a row for each possible triple (x, y, x^y)
68    fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
69        let rows: Vec<_> = (0..(1 << M) * (1 << M))
70            .flat_map(|i| {
71                let x = i / (1 << M);
72                let y = i % (1 << M);
73                let z = x ^ y;
74                [x, y, z].map(F::from_canonical_u32)
75            })
76            .collect();
77
78        Some(RowMajorMatrix::new(rows, NUM_XOR_LOOKUP_PREPROCESSED_COLS))
79    }
80}
81
82impl<AB, const M: usize> Air<AB> for XorLookupAir<M>
83where
84    AB: InteractionBuilder + PairBuilder,
85{
86    fn eval(&self, builder: &mut AB) {
87        let main = builder.main();
88        let preprocessed = builder.preprocessed();
89
90        let prep_local = preprocessed.row_slice(0);
91        let prep_local: &XorLookupPreprocessedCols<AB::Var> = (*prep_local).borrow();
92        let local = main.row_slice(0);
93        let local: &XorLookupCols<AB::Var> = (*local).borrow();
94
95        self.bus
96            .receive(prep_local.x, prep_local.y, prep_local.z)
97            .eval(builder, local.mult);
98    }
99}
100
101/// This chip gets requests to compute the xor of two numbers x and y of at most M bits.
102/// It generates a preprocessed table with a row for each possible triple (x, y, x^y)
103/// and keeps count of the number of times each triple is requested for the single main trace
104/// column.
105#[derive(Debug)]
106pub struct XorLookupChip<const M: usize> {
107    pub air: XorLookupAir<M>,
108    /// Tracks the count of each (x,y) pair requested
109    pub count: Vec<Vec<AtomicU32>>,
110}
111
112impl<const M: usize> XorLookupChip<M> {
113    pub fn new(bus: BusIndex) -> Self {
114        let mut count = vec![];
115        for _ in 0..(1 << M) {
116            let mut row = vec![];
117            for _ in 0..(1 << M) {
118                row.push(AtomicU32::new(0));
119            }
120            count.push(row);
121        }
122        Self {
123            air: XorLookupAir::new(XorBus(LookupBus::new(bus))),
124            count,
125        }
126    }
127
128    /// The xor bus this chip interacts with
129    pub fn bus(&self) -> XorBus {
130        self.air.bus
131    }
132
133    fn calc_xor(&self, x: u32, y: u32) -> u32 {
134        x ^ y
135    }
136
137    /// Request an XOR operation for inputs x and y
138    /// Increments the count for this (x,y) pair and returns x ⊕ y
139    pub fn request(&self, x: u32, y: u32) -> u32 {
140        let val_atomic = &self.count[x as usize][y as usize];
141        val_atomic.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
142
143        self.calc_xor(x, y)
144    }
145
146    /// Resets all request counters to zero
147    pub fn clear(&self) {
148        for i in 0..(1 << M) {
149            for j in 0..(1 << M) {
150                self.count[i][j].store(0, std::sync::atomic::Ordering::Relaxed);
151            }
152        }
153    }
154
155    /// Generates the multiplicity trace based on requests
156    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
157        debug_assert_eq!(self.count.len(), 1 << M);
158        let multiplicities: Vec<_> = self
159            .count
160            .iter()
161            .flat_map(|count_x| {
162                debug_assert_eq!(count_x.len(), 1 << M);
163                count_x
164                    .iter()
165                    .map(|count_xy| F::from_canonical_u32(count_xy.load(atomic::Ordering::SeqCst)))
166            })
167            .collect();
168
169        RowMajorMatrix::new_col(multiplicities)
170    }
171}
172
173impl<R, SC: StarkGenericConfig, const M: usize> Chip<R, CpuBackend<SC>> for XorLookupChip<M> {
174    fn generate_proving_ctx(&self, _: R) -> AirProvingContext<CpuBackend<SC>> {
175        let trace = self.generate_trace::<Val<SC>>();
176        AirProvingContext::simple_no_pis(Arc::new(trace))
177    }
178}
179
180impl<const M: usize> ChipUsageGetter for XorLookupChip<M> {
181    fn air_name(&self) -> String {
182        get_air_name(&self.air)
183    }
184
185    fn current_trace_height(&self) -> usize {
186        1 << (2 * M)
187    }
188
189    fn trace_width(&self) -> usize {
190        NUM_XOR_LOOKUP_COLS
191    }
192}