openvm_circuit_primitives/range_tuple/
mod.rs

//! Range check all components of a tuple simultaneously.
//! When you want to range check `(x, y)` to `x_bits` and `y_bits` bits respectively
//! and `2^{x_bits + y_bits} < ~2^20`, this chip performs the check in a single
//! interaction, versus the two interactions necessary if you were to use
//! [VariableRangeCheckerChip](super::var_range::VariableRangeCheckerChip) instead.
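//!
//! A minimal usage sketch, assuming the bus is constructed with
//! `RangeTupleCheckerBus::new(bus_index, sizes)` (see the `bus` module and the
//! `tests` module for the exact API):
//!
//! ```ignore
//! use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, RangeTupleCheckerChip};
//!
//! // Check pairs `(x, y)` with `x < 2^8` and `y < 2^8`; the preprocessed trace has 2^16 rows.
//! let bus = RangeTupleCheckerBus::new(0, [1 << 8, 1 << 8]);
//! let checker = RangeTupleCheckerChip::<2>::new(bus);
//! // Record one range check of the tuple `(3, 250)`.
//! checker.add_count(&[3, 250]);
//! ```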

use std::{
    mem::size_of,
    sync::{atomic::AtomicU32, Arc},
};

use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_stark_backend::{
    config::{StarkGenericConfig, Val},
    interaction::InteractionBuilder,
    p3_air::{Air, BaseAir, PairBuilder},
    p3_field::{Field, PrimeField32},
    p3_matrix::{dense::RowMajorMatrix, Matrix},
    prover::{cpu::CpuBackend, types::AirProvingContext},
    rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
    Chip, ChipUsageGetter,
};

mod bus;
pub use bus::*;

#[cfg(feature = "cuda")]
mod cuda;
#[cfg(feature = "cuda")]
pub use cuda::*;

#[cfg(test)]
pub mod tests;

#[repr(C)]
#[derive(Default, Copy, Clone, AlignedBorrow)]
pub struct RangeTupleCols<T> {
    /// Number of range checks requested for each tuple combination
    pub mult: T,
}

#[derive(Default, Clone)]
pub struct RangeTuplePreprocessedCols<T> {
    /// Contains all possible tuple combinations within specified ranges
    pub tuple: Vec<T>,
}

pub const NUM_RANGE_TUPLE_COLS: usize = size_of::<RangeTupleCols<u8>>();

#[derive(Clone, Copy, Debug)]
pub struct RangeTupleCheckerAir<const N: usize> {
    pub bus: RangeTupleCheckerBus<N>,
}

impl<const N: usize> RangeTupleCheckerAir<N> {
    pub fn height(&self) -> u32 {
        self.bus.sizes.iter().product()
    }
}
impl<F: Field, const N: usize> BaseAirWithPublicValues<F> for RangeTupleCheckerAir<N> {}
impl<F: Field, const N: usize> PartitionedBaseAir<F> for RangeTupleCheckerAir<N> {}

impl<F: Field, const N: usize> BaseAir<F> for RangeTupleCheckerAir<N> {
    fn width(&self) -> usize {
        NUM_RANGE_TUPLE_COLS
    }

    fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
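        // Enumerate every tuple in lexicographic order by treating `row` as a
        // mixed-radix counter over `bus.sizes`, with the last coordinate
        // incrementing fastest; one preprocessed row per tuple.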
        let mut unrolled_matrix = Vec::with_capacity((self.height() as usize) * N);
        let mut row = [0u32; N];
        for _ in 0..self.height() {
            unrolled_matrix.extend(row);
            for i in (0..N).rev() {
                if row[i] < self.bus.sizes[i] - 1 {
                    row[i] += 1;
                    break;
                }
                row[i] = 0;
            }
        }
        Some(RowMajorMatrix::new(
            unrolled_matrix
                .iter()
                .map(|&v| F::from_canonical_u32(v))
                .collect(),
            N,
        ))
    }
}

impl<AB: InteractionBuilder + PairBuilder, const N: usize> Air<AB> for RangeTupleCheckerAir<N> {
    fn eval(&self, builder: &mut AB) {
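        // Each row receives its preprocessed tuple on the bus with multiplicity
        // `mult`, balancing the sends made by chips that request range checks.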
        let preprocessed = builder.preprocessed();
        let prep_local = preprocessed.row_slice(0);
        let prep_local = RangeTuplePreprocessedCols {
            tuple: (*prep_local).to_vec(),
        };
        let main = builder.main();
        let local = main.row_slice(0);
        let local = RangeTupleCols { mult: (*local)[0] };

        self.bus.receive(prep_local.tuple).eval(builder, local.mult);
    }
}

#[derive(Debug)]
pub struct RangeTupleCheckerChip<const N: usize> {
    pub air: RangeTupleCheckerAir<N>,
    pub count: Vec<Arc<AtomicU32>>,
}

pub type SharedRangeTupleCheckerChip<const N: usize> = Arc<RangeTupleCheckerChip<N>>;

impl<const N: usize> RangeTupleCheckerChip<N> {
    pub fn new(bus: RangeTupleCheckerBus<N>) -> Self {
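        // One atomic multiplicity counter per possible tuple, in the same order
        // as the rows of the preprocessed trace.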
        let range_max = bus.sizes.iter().product();
        let count = (0..range_max)
            .map(|_| Arc::new(AtomicU32::new(0)))
            .collect();

        Self {
            air: RangeTupleCheckerAir { bus },
            count,
        }
    }

    pub fn bus(&self) -> &RangeTupleCheckerBus<N> {
        &self.air.bus
    }

    pub fn sizes(&self) -> &[u32; N] {
        &self.air.bus.sizes
    }

    pub fn add_count(&self, ids: &[u32]) {
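        // Flatten the tuple into a row index using the same mixed-radix order
        // as the preprocessed trace: index = ((id_0 * sz_1 + id_1) * sz_2 + id_2) ...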
        let index = ids
            .iter()
            .zip(self.air.bus.sizes.iter())
            .fold(0, |acc, (id, sz)| acc * sz + id) as usize;
        assert!(
            index < self.count.len(),
            "range exceeded: {} >= {}",
            index,
            self.count.len()
        );
        let val_atomic = &self.count[index];
        val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    pub fn clear(&self) {
        for val in &self.count {
            val.store(0, std::sync::atomic::Ordering::Relaxed);
        }
    }

    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
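        // Drain the multiplicity counters: `swap(0, ..)` reads each count and
        // resets it to zero in one atomic operation.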
        let rows = self
            .count
            .iter()
            .map(|c| F::from_canonical_u32(c.swap(0, std::sync::atomic::Ordering::Relaxed)))
            .collect::<Vec<_>>();
        RowMajorMatrix::new(rows, 1)
    }
}

impl<R, SC: StarkGenericConfig, const N: usize> Chip<R, CpuBackend<SC>> for RangeTupleCheckerChip<N>
where
    Val<SC>: PrimeField32,
{
    fn generate_proving_ctx(&self, _: R) -> AirProvingContext<CpuBackend<SC>> {
        let trace = self.generate_trace::<Val<SC>>();
        AirProvingContext::simple_no_pis(Arc::new(trace))
    }
}

impl<const N: usize> ChipUsageGetter for RangeTupleCheckerChip<N> {
    fn air_name(&self) -> String {
        get_air_name(&self.air)
    }
    fn constant_trace_height(&self) -> Option<usize> {
        Some(self.count.len())
    }
    fn current_trace_height(&self) -> usize {
        self.count.len()
    }
    fn trace_width(&self) -> usize {
        NUM_RANGE_TUPLE_COLS
    }
}