openvm_circuit_primitives/var_range/
mod.rs

//! A chip that uses a preprocessed trace to provide a lookup table for range checking that a
//! variable `x` fits in `b` bits (i.e. `x < 2^b`), where `b` can be any integer in
//! `[0, range_max_bits]`. In other words, the same chip can be used to range check for
//! different bit sizes. We define `0` to have `0` bits.
5
6use core::mem::size_of;
7use std::{
8    borrow::{Borrow, BorrowMut},
9    sync::{atomic::AtomicU32, Arc},
10};
11
12use openvm_circuit_primitives_derive::AlignedBorrow;
13use openvm_stark_backend::{
14    config::{StarkGenericConfig, Val},
15    interaction::InteractionBuilder,
16    p3_air::{Air, BaseAir, PairBuilder},
17    p3_field::{Field, PrimeField32},
18    p3_matrix::{dense::RowMajorMatrix, Matrix},
19    prover::{cpu::CpuBackend, types::AirProvingContext},
20    rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
21    Chip, ChipUsageGetter,
22};
23use tracing::instrument;
24
25mod bus;
26pub use bus::*;
27
28#[cfg(feature = "cuda")]
29mod cuda;
30#[cfg(feature = "cuda")]
31pub use cuda::*;
32
33#[cfg(test)]
34pub mod tests;
35
36#[derive(Default, AlignedBorrow, Copy, Clone)]
37#[repr(C)]
38pub struct VariableRangeCols<T> {
39    /// Number of range checks requested for each (value, max_bits) pair
40    pub mult: T,
41}
42
43#[derive(Default, AlignedBorrow, Copy, Clone)]
44#[repr(C)]
45pub struct VariableRangePreprocessedCols<T> {
46    /// The value being range checked
47    pub value: T,
48    /// The maximum number of bits for this value
49    pub max_bits: T,
50}
51
/// Width of the main trace. Measured with `T = u8`: under `#[repr(C)]` every `u8` field takes
/// exactly one byte with no padding, so the byte size equals the number of columns.
pub const NUM_VARIABLE_RANGE_COLS: usize = size_of::<VariableRangeCols<u8>>();
/// Width of the preprocessed trace, measured the same way.
pub const NUM_VARIABLE_RANGE_PREPROCESSED_COLS: usize =
    size_of::<VariableRangePreprocessedCols<u8>>();
55
56#[derive(Clone, Copy, Debug, derive_new::new)]
57pub struct VariableRangeCheckerAir {
58    pub bus: VariableRangeCheckerBus,
59}
60
61impl VariableRangeCheckerAir {
62    pub fn range_max_bits(&self) -> usize {
63        self.bus.range_max_bits
64    }
65}
66
67impl<F: Field> BaseAirWithPublicValues<F> for VariableRangeCheckerAir {}
68impl<F: Field> PartitionedBaseAir<F> for VariableRangeCheckerAir {}
69impl<F: Field> BaseAir<F> for VariableRangeCheckerAir {
70    fn width(&self) -> usize {
71        NUM_VARIABLE_RANGE_COLS
72    }
73
74    fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
75        let rows: Vec<F> = [F::ZERO; NUM_VARIABLE_RANGE_PREPROCESSED_COLS]
76            .into_iter()
77            .chain((0..=self.range_max_bits()).flat_map(|bits| {
78                (0..(1 << bits)).flat_map(move |value| {
79                    [F::from_canonical_u32(value), F::from_canonical_usize(bits)].into_iter()
80                })
81            }))
82            .collect();
83        Some(RowMajorMatrix::new(
84            rows,
85            NUM_VARIABLE_RANGE_PREPROCESSED_COLS,
86        ))
87    }
88}
89
90impl<AB: InteractionBuilder + PairBuilder> Air<AB> for VariableRangeCheckerAir {
91    fn eval(&self, builder: &mut AB) {
92        let preprocessed = builder.preprocessed();
93        let prep_local = preprocessed.row_slice(0);
94        let prep_local: &VariableRangePreprocessedCols<AB::Var> = (*prep_local).borrow();
95        let main = builder.main();
96        let local = main.row_slice(0);
97        let local: &VariableRangeCols<AB::Var> = (*local).borrow();
98        // Omit creating separate bridge.rs file for brevity
99        self.bus
100            .receive(prep_local.value, prep_local.max_bits)
101            .eval(builder, local.mult);
102    }
103}
104
/// Trace-generation side of the variable range checker: one multiplicity counter per row of
/// the AIR's preprocessed lookup table.
pub struct VariableRangeCheckerChip {
    /// The AIR whose lookup table these counters correspond to.
    pub air: VariableRangeCheckerAir,
    /// `count[i]` is the number of recorded range-check requests for table row `i`. The row for
    /// `(value, max_bits)` is `2^max_bits + value`; row 0 is the extra `[0, 0]` row.
    /// Atomic counters let requests be recorded through `&self`.
    pub count: Vec<AtomicU32>,
}
109
/// Reference-counted handle so several owners can record range checks into one shared set of
/// counters (all recording methods take `&self`).
pub type SharedVariableRangeCheckerChip = Arc<VariableRangeCheckerChip>;
111
112impl VariableRangeCheckerChip {
113    pub fn new(bus: VariableRangeCheckerBus) -> Self {
114        let num_rows = (1 << (bus.range_max_bits + 1)) as usize;
115        let count = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
116        Self {
117            air: VariableRangeCheckerAir::new(bus),
118            count,
119        }
120    }
121
122    pub fn bus(&self) -> VariableRangeCheckerBus {
123        self.air.bus
124    }
125
126    pub fn range_max_bits(&self) -> usize {
127        self.air.range_max_bits()
128    }
129
130    pub fn air_width(&self) -> usize {
131        NUM_VARIABLE_RANGE_COLS
132    }
133
134    #[instrument(
135        name = "VariableRangeCheckerChip::add_count",
136        skip(self),
137        level = "trace"
138    )]
139    pub fn add_count(&self, value: u32, max_bits: usize) {
140        // index is 2^max_bits + value - 1 + 1 for the extra [0, 0] row
141        // if each [value, max_bits] is valid, the sends multiset will be exactly the receives
142        // multiset
143        let idx = (1 << max_bits) + (value as usize);
144        assert!(
145            idx < self.count.len(),
146            "range exceeded: {} >= {}",
147            idx,
148            self.count.len()
149        );
150        let val_atomic = &self.count[idx];
151        val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
152    }
153
154    pub fn clear(&self) {
155        for i in 0..self.count.len() {
156            self.count[i].store(0, std::sync::atomic::Ordering::Relaxed);
157        }
158    }
159
160    /// Generates trace and resets the internal counters all to 0.
161    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
162        let mut rows = F::zero_vec(self.count.len() * NUM_VARIABLE_RANGE_COLS);
163        for (n, row) in rows.chunks_mut(NUM_VARIABLE_RANGE_COLS).enumerate() {
164            let cols: &mut VariableRangeCols<F> = row.borrow_mut();
165            cols.mult =
166                F::from_canonical_u32(self.count[n].swap(0, std::sync::atomic::Ordering::Relaxed));
167        }
168        RowMajorMatrix::new(rows, NUM_VARIABLE_RANGE_COLS)
169    }
170
171    /// Range checks that `value` is `bits` bits by decomposing into `limbs` where all but
172    /// last limb is `range_max_bits` bits. Assumes there are enough limbs.
173    pub fn decompose<F: Field>(&self, mut value: u32, bits: usize, limbs: &mut [F]) {
174        debug_assert!(
175            limbs.len() >= bits.div_ceil(self.range_max_bits()),
176            "Not enough limbs: len {}",
177            limbs.len()
178        );
179        let mask = (1 << self.range_max_bits()) - 1;
180        let mut bits_remaining = bits;
181        for limb in limbs.iter_mut() {
182            let limb_u32 = value & mask;
183            *limb = F::from_canonical_u32(limb_u32);
184            self.add_count(limb_u32, bits_remaining.min(self.range_max_bits()));
185
186            value >>= self.range_max_bits();
187            bits_remaining = bits_remaining.saturating_sub(self.range_max_bits());
188        }
189        debug_assert_eq!(value, 0);
190        debug_assert_eq!(bits_remaining, 0);
191    }
192}
193
194// We allow any `R` type so this can work with arbitrary record arenas.
195impl<R, SC: StarkGenericConfig> Chip<R, CpuBackend<SC>> for VariableRangeCheckerChip
196where
197    Val<SC>: PrimeField32,
198{
199    /// Generates trace and resets the internal counters all to 0.
200    fn generate_proving_ctx(&self, _: R) -> AirProvingContext<CpuBackend<SC>> {
201        let trace = self.generate_trace::<Val<SC>>();
202        AirProvingContext::simple_no_pis(Arc::new(trace))
203    }
204}
205
206impl ChipUsageGetter for VariableRangeCheckerChip {
207    fn air_name(&self) -> String {
208        get_air_name(&self.air)
209    }
210    fn constant_trace_height(&self) -> Option<usize> {
211        Some(self.count.len())
212    }
213    fn current_trace_height(&self) -> usize {
214        self.count.len()
215    }
216    fn trace_width(&self) -> usize {
217        NUM_VARIABLE_RANGE_COLS
218    }
219}