openvm_circuit_primitives/xor/lookup/
mod.rs

//! A chip that uses a preprocessed trace to provide a lookup table for XOR operations
//! between two numbers `x` and `y` of at most `M` bits.
//! It generates a preprocessed table with a row for each possible triple `(x, y, x^y)`
//! and keeps count of the number of times each triple is requested.
5
6use std::{
7    borrow::Borrow,
8    mem::size_of,
9    sync::{
10        atomic::{self, AtomicU32},
11        Arc,
12    },
13};
14
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_stark_backend::{
17    config::{StarkGenericConfig, Val},
18    interaction::{BusIndex, InteractionBuilder, LookupBus},
19    p3_air::{Air, BaseAir, PairBuilder},
20    p3_field::Field,
21    p3_matrix::{dense::RowMajorMatrix, Matrix},
22    prover::types::AirProofInput,
23    rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
24    AirRef, Chip, ChipUsageGetter,
25};
26
27use super::bus::XorBus;
28
29#[cfg(test)]
30mod tests;
31
32/// Columns for the main trace of the XOR lookup
33#[repr(C)]
34#[derive(Copy, Clone, Debug, AlignedBorrow)]
35pub struct XorLookupCols<T> {
36    /// Multiplicity counter tracking the number of XOR operations requested for each triple
37    pub mult: T,
38}
39
40/// Columns for the preprocessed table of the XOR lookup
41#[repr(C)]
42#[derive(Copy, Clone, Debug, AlignedBorrow)]
43pub struct XorLookupPreprocessedCols<T> {
44    pub x: T,
45    pub y: T,
46    /// XOR result (x ⊕ y)
47    pub z: T,
48}
49
50pub const NUM_XOR_LOOKUP_COLS: usize = size_of::<XorLookupCols<u8>>();
51pub const NUM_XOR_LOOKUP_PREPROCESSED_COLS: usize = size_of::<XorLookupPreprocessedCols<u8>>();
52
53/// Xor via preprocessed lookup table. Can only be used if inputs have less than approximately 10-bits.
54#[derive(Clone, Copy, Debug, derive_new::new)]
55pub struct XorLookupAir<const M: usize> {
56    pub bus: XorBus,
57}
58
59impl<F: Field, const M: usize> BaseAirWithPublicValues<F> for XorLookupAir<M> {}
60impl<F: Field, const M: usize> PartitionedBaseAir<F> for XorLookupAir<M> {}
61impl<F: Field, const M: usize> BaseAir<F> for XorLookupAir<M> {
62    fn width(&self) -> usize {
63        NUM_XOR_LOOKUP_COLS
64    }
65
66    /// Generates a preprocessed table with a row for each possible triple (x, y, x^y)
67    fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
68        let rows: Vec<_> = (0..(1 << M) * (1 << M))
69            .flat_map(|i| {
70                let x = i / (1 << M);
71                let y = i % (1 << M);
72                let z = x ^ y;
73                [x, y, z].map(F::from_canonical_u32)
74            })
75            .collect();
76
77        Some(RowMajorMatrix::new(rows, NUM_XOR_LOOKUP_PREPROCESSED_COLS))
78    }
79}
80
81impl<AB, const M: usize> Air<AB> for XorLookupAir<M>
82where
83    AB: InteractionBuilder + PairBuilder,
84{
85    fn eval(&self, builder: &mut AB) {
86        let main = builder.main();
87        let preprocessed = builder.preprocessed();
88
89        let prep_local = preprocessed.row_slice(0);
90        let prep_local: &XorLookupPreprocessedCols<AB::Var> = (*prep_local).borrow();
91        let local = main.row_slice(0);
92        let local: &XorLookupCols<AB::Var> = (*local).borrow();
93
94        self.bus
95            .receive(prep_local.x, prep_local.y, prep_local.z)
96            .eval(builder, local.mult);
97    }
98}
99
100/// This chip gets requests to compute the xor of two numbers x and y of at most M bits.
101/// It generates a preprocessed table with a row for each possible triple (x, y, x^y)
102/// and keeps count of the number of times each triple is requested for the single main trace column.
103#[derive(Debug)]
104pub struct XorLookupChip<const M: usize> {
105    pub air: XorLookupAir<M>,
106    /// Tracks the count of each (x,y) pair requested
107    pub count: Vec<Vec<AtomicU32>>,
108}
109
110impl<const M: usize> XorLookupChip<M> {
111    pub fn new(bus: BusIndex) -> Self {
112        let mut count = vec![];
113        for _ in 0..(1 << M) {
114            let mut row = vec![];
115            for _ in 0..(1 << M) {
116                row.push(AtomicU32::new(0));
117            }
118            count.push(row);
119        }
120        Self {
121            air: XorLookupAir::new(XorBus(LookupBus::new(bus))),
122            count,
123        }
124    }
125
126    /// The xor bus this chip interacts with
127    pub fn bus(&self) -> XorBus {
128        self.air.bus
129    }
130
131    fn calc_xor(&self, x: u32, y: u32) -> u32 {
132        x ^ y
133    }
134
135    /// Request an XOR operation for inputs x and y
136    /// Increments the count for this (x,y) pair and returns x ⊕ y
137    pub fn request(&self, x: u32, y: u32) -> u32 {
138        let val_atomic = &self.count[x as usize][y as usize];
139        val_atomic.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
140
141        self.calc_xor(x, y)
142    }
143
144    /// Resets all request counters to zero
145    pub fn clear(&self) {
146        for i in 0..(1 << M) {
147            for j in 0..(1 << M) {
148                self.count[i][j].store(0, std::sync::atomic::Ordering::Relaxed);
149            }
150        }
151    }
152
153    /// Generates the multiplicity trace based on requests
154    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
155        debug_assert_eq!(self.count.len(), 1 << M);
156        let multiplicities: Vec<_> = self
157            .count
158            .iter()
159            .flat_map(|count_x| {
160                debug_assert_eq!(count_x.len(), 1 << M);
161                count_x
162                    .iter()
163                    .map(|count_xy| F::from_canonical_u32(count_xy.load(atomic::Ordering::SeqCst)))
164            })
165            .collect();
166
167        RowMajorMatrix::new_col(multiplicities)
168    }
169}
170
171impl<SC: StarkGenericConfig, const M: usize> Chip<SC> for XorLookupChip<M> {
172    fn air(&self) -> AirRef<SC> {
173        Arc::new(self.air)
174    }
175
176    fn generate_air_proof_input(self) -> AirProofInput<SC> {
177        let trace = self.generate_trace::<Val<SC>>();
178        AirProofInput::simple_no_pis(trace)
179    }
180}
181
182impl<const M: usize> ChipUsageGetter for XorLookupChip<M> {
183    fn air_name(&self) -> String {
184        get_air_name(&self.air)
185    }
186
187    fn current_trace_height(&self) -> usize {
188        1 << (2 * M)
189    }
190
191    fn trace_width(&self) -> usize {
192        NUM_XOR_LOOKUP_COLS
193    }
194}