// openvm_circuit_primitives/range_tuple/mod.rs

//! Range check a tuple simultaneously.
//!
//! When you want to range check `(x, y)` to `x_bits` and `y_bits` respectively,
//! and `2^{x_bits + y_bits} < ~2^20`, this chip performs the range check in one
//! interaction, versus the two interactions needed with
//! [VariableRangeCheckerChip](super::var_range::VariableRangeCheckerChip).
5
6use std::{
7    mem::size_of,
8    sync::{atomic::AtomicU32, Arc},
9};
10
11use openvm_circuit_primitives_derive::AlignedBorrow;
12use openvm_stark_backend::{
13    config::{StarkGenericConfig, Val},
14    interaction::InteractionBuilder,
15    p3_air::{Air, BaseAir, PairBuilder},
16    p3_field::{Field, PrimeField32},
17    p3_matrix::{dense::RowMajorMatrix, Matrix},
18    prover::types::AirProofInput,
19    rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
20    AirRef, Chip, ChipUsageGetter,
21};
22
23mod bus;
24
25#[cfg(test)]
26pub mod tests;
27
28pub use bus::*;
29
/// Main trace columns for the range tuple checker.
///
/// The trace has one row per possible tuple value (see the preprocessed trace);
/// the single main column is the request multiplicity for that row's tuple.
#[repr(C)]
#[derive(Default, Copy, Clone, AlignedBorrow)]
pub struct RangeTupleCols<T> {
    /// Number of range checks requested for each tuple combination
    pub mult: T,
}
36
/// Preprocessed trace columns: one row per tuple in the cartesian product of
/// the configured ranges.
#[derive(Default, Clone)]
pub struct RangeTuplePreprocessedCols<T> {
    /// Contains all possible tuple combinations within specified ranges
    pub tuple: Vec<T>,
}
42
43pub const NUM_RANGE_TUPLE_COLS: usize = size_of::<RangeTupleCols<u8>>();
44
/// AIR for checking that each requested `N`-tuple lies within the per-coordinate
/// ranges configured on the bus.
#[derive(Clone, Copy, Debug)]
pub struct RangeTupleCheckerAir<const N: usize> {
    pub bus: RangeTupleCheckerBus<N>,
}
49
50impl<const N: usize> RangeTupleCheckerAir<N> {
51    pub fn height(&self) -> u32 {
52        self.bus.sizes.iter().product()
53    }
54}
// Marker impls: this AIR exposes no public values and is not partitioned.
impl<F: Field, const N: usize> BaseAirWithPublicValues<F> for RangeTupleCheckerAir<N> {}
impl<F: Field, const N: usize> PartitionedBaseAir<F> for RangeTupleCheckerAir<N> {}
57
58impl<F: Field, const N: usize> BaseAir<F> for RangeTupleCheckerAir<N> {
59    fn width(&self) -> usize {
60        NUM_RANGE_TUPLE_COLS
61    }
62
63    fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
64        let mut unrolled_matrix = Vec::with_capacity((self.height() as usize) * N);
65        let mut row = [0u32; N];
66        for _ in 0..self.height() {
67            unrolled_matrix.extend(row);
68            for i in (0..N).rev() {
69                if row[i] < self.bus.sizes[i] - 1 {
70                    row[i] += 1;
71                    break;
72                }
73                row[i] = 0;
74            }
75        }
76        Some(RowMajorMatrix::new(
77            unrolled_matrix
78                .iter()
79                .map(|&v| F::from_canonical_u32(v))
80                .collect(),
81            N,
82        ))
83    }
84}
85
impl<AB: InteractionBuilder + PairBuilder, const N: usize> Air<AB> for RangeTupleCheckerAir<N> {
    /// Receives, with multiplicity `mult`, the tuple stored in this row of the
    /// preprocessed trace on the range-tuple bus.
    fn eval(&self, builder: &mut AB) {
        let preprocessed = builder.preprocessed();
        let prep_local = preprocessed.row_slice(0);
        // Copy the preprocessed row into an owned Vec so the row-slice guard
        // does not have to outlive the mutable borrow of `builder` below.
        let prep_local = RangeTuplePreprocessedCols {
            tuple: (*prep_local).to_vec(),
        };
        let main = builder.main();
        // The main trace has a single column: the multiplicity (copied out by value).
        let local = main.row_slice(0);
        let local = RangeTupleCols { mult: (*local)[0] };

        self.bus.receive(prep_local.tuple).eval(builder, local.mult);
    }
}
100
/// Stateful chip that accumulates range-check request counts for each tuple.
#[derive(Debug)]
pub struct RangeTupleCheckerChip<const N: usize> {
    pub air: RangeTupleCheckerAir<N>,
    // One atomic counter per possible tuple, indexed in the same lexicographic
    // order as the preprocessed trace; atomics allow recording through `&self`.
    pub count: Vec<Arc<AtomicU32>>,
}
106
/// Cheaply-cloneable handle to a [`RangeTupleCheckerChip`] shared between chips.
#[derive(Debug, Clone)]
pub struct SharedRangeTupleCheckerChip<const N: usize>(Arc<RangeTupleCheckerChip<N>>);
109
110impl<const N: usize> RangeTupleCheckerChip<N> {
111    pub fn new(bus: RangeTupleCheckerBus<N>) -> Self {
112        let range_max = bus.sizes.iter().product();
113        let count = (0..range_max)
114            .map(|_| Arc::new(AtomicU32::new(0)))
115            .collect();
116
117        Self {
118            air: RangeTupleCheckerAir { bus },
119            count,
120        }
121    }
122
123    pub fn bus(&self) -> &RangeTupleCheckerBus<N> {
124        &self.air.bus
125    }
126
127    pub fn sizes(&self) -> &[u32; N] {
128        &self.air.bus.sizes
129    }
130
131    pub fn add_count(&self, ids: &[u32]) {
132        let index = ids
133            .iter()
134            .zip(self.air.bus.sizes.iter())
135            .fold(0, |acc, (id, sz)| acc * sz + id) as usize;
136        assert!(
137            index < self.count.len(),
138            "range exceeded: {} >= {}",
139            index,
140            self.count.len()
141        );
142        let val_atomic = &self.count[index];
143        val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
144    }
145
146    pub fn clear(&self) {
147        for val in &self.count {
148            val.store(0, std::sync::atomic::Ordering::Relaxed);
149        }
150    }
151
152    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
153        let rows = self
154            .count
155            .iter()
156            .map(|c| F::from_canonical_u32(c.load(std::sync::atomic::Ordering::SeqCst)))
157            .collect::<Vec<_>>();
158        RowMajorMatrix::new(rows, 1)
159    }
160}
161
// Thin delegation layer: every method forwards to the inner shared chip.
impl<const N: usize> SharedRangeTupleCheckerChip<N> {
    pub fn new(bus: RangeTupleCheckerBus<N>) -> Self {
        Self(Arc::new(RangeTupleCheckerChip::new(bus)))
    }
    pub fn bus(&self) -> &RangeTupleCheckerBus<N> {
        self.0.bus()
    }

    pub fn sizes(&self) -> &[u32; N] {
        self.0.sizes()
    }

    pub fn add_count(&self, ids: &[u32]) {
        self.0.add_count(ids);
    }

    pub fn clear(&self) {
        self.0.clear();
    }

    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
        self.0.generate_trace()
    }
}
186
impl<SC: StarkGenericConfig, const N: usize> Chip<SC> for RangeTupleCheckerChip<N>
where
    Val<SC>: PrimeField32,
{
    fn air(&self) -> AirRef<SC> {
        // `RangeTupleCheckerAir` is `Copy`, so this is a cheap by-value clone.
        Arc::new(self.air)
    }

    /// Consumes the chip and packages its multiplicity trace (no public values).
    fn generate_air_proof_input(self) -> AirProofInput<SC> {
        let trace = self.generate_trace::<Val<SC>>();
        AirProofInput::simple_no_pis(trace)
    }
}
200
impl<SC: StarkGenericConfig, const N: usize> Chip<SC> for SharedRangeTupleCheckerChip<N>
where
    Val<SC>: PrimeField32,
{
    fn air(&self) -> AirRef<SC> {
        self.0.air()
    }

    fn generate_air_proof_input(self) -> AirProofInput<SC> {
        // NOTE(review): `self.0` is an `Arc`; delegating a by-value method
        // presumably relies on a backend `Chip` impl for the `Arc` wrapper —
        // confirm against openvm_stark_backend.
        self.0.generate_air_proof_input()
    }
}
213
impl<const N: usize> ChipUsageGetter for RangeTupleCheckerChip<N> {
    fn air_name(&self) -> String {
        get_air_name(&self.air)
    }
    // The trace height is fixed at construction: one row per possible tuple,
    // so the constant and current heights are both `count.len()`.
    fn constant_trace_height(&self) -> Option<usize> {
        Some(self.count.len())
    }
    fn current_trace_height(&self) -> usize {
        self.count.len()
    }
    fn trace_width(&self) -> usize {
        NUM_RANGE_TUPLE_COLS
    }
}
228
// Usage metrics delegate to the inner shared chip.
impl<const N: usize> ChipUsageGetter for SharedRangeTupleCheckerChip<N> {
    fn air_name(&self) -> String {
        self.0.air_name()
    }

    fn constant_trace_height(&self) -> Option<usize> {
        self.0.constant_trace_height()
    }

    fn current_trace_height(&self) -> usize {
        self.0.current_trace_height()
    }

    fn trace_width(&self) -> usize {
        self.0.trace_width()
    }
}
246
// Borrow the underlying chip directly from the shared handle.
impl<const N: usize> AsRef<RangeTupleCheckerChip<N>> for SharedRangeTupleCheckerChip<N> {
    fn as_ref(&self) -> &RangeTupleCheckerChip<N> {
        &self.0
    }
}