// openvm_circuit_primitives/var_range/mod.rs

//! A chip that uses a preprocessed trace to provide a lookup table for range checking that a
//! variable `x` has `b` bits (i.e. `x < 2^b`), where `b` can be any integer in
//! `[0, range_max_bits]`. In other words, the same chip can be used to range check for
//! different bit sizes. We define `0` to have `0` bits.
use core::mem::size_of;
use std::{
    borrow::{Borrow, BorrowMut},
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
};

use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_stark_backend::{
    config::{StarkGenericConfig, Val},
    interaction::InteractionBuilder,
    p3_air::{Air, BaseAir, PairBuilder},
    p3_field::{Field, PrimeField32},
    p3_matrix::{dense::RowMajorMatrix, Matrix},
    prover::types::AirProofInput,
    rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
    AirRef, Chip, ChipUsageGetter,
};
use tracing::instrument;

mod bus;
#[cfg(test)]
pub mod tests;

pub use bus::*;
31#[derive(Default, AlignedBorrow, Copy, Clone)]
32#[repr(C)]
33pub struct VariableRangeCols<T> {
34    /// Number of range checks requested for each (value, max_bits) pair
35    pub mult: T,
36}
37
38#[derive(Default, AlignedBorrow, Copy, Clone)]
39#[repr(C)]
40pub struct VariableRangePreprocessedCols<T> {
41    /// The value being range checked
42    pub value: T,
43    /// The maximum number of bits for this value
44    pub max_bits: T,
45}
46
47pub const NUM_VARIABLE_RANGE_COLS: usize = size_of::<VariableRangeCols<u8>>();
48pub const NUM_VARIABLE_RANGE_PREPROCESSED_COLS: usize =
49    size_of::<VariableRangePreprocessedCols<u8>>();
50
51#[derive(Clone, Copy, Debug, derive_new::new)]
52pub struct VariableRangeCheckerAir {
53    pub bus: VariableRangeCheckerBus,
54}
55
56impl VariableRangeCheckerAir {
57    pub fn range_max_bits(&self) -> usize {
58        self.bus.range_max_bits
59    }
60}
61
62impl<F: Field> BaseAirWithPublicValues<F> for VariableRangeCheckerAir {}
63impl<F: Field> PartitionedBaseAir<F> for VariableRangeCheckerAir {}
64impl<F: Field> BaseAir<F> for VariableRangeCheckerAir {
65    fn width(&self) -> usize {
66        NUM_VARIABLE_RANGE_COLS
67    }
68
69    fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
70        let rows: Vec<F> = [F::ZERO; NUM_VARIABLE_RANGE_PREPROCESSED_COLS]
71            .into_iter()
72            .chain((0..=self.range_max_bits()).flat_map(|bits| {
73                (0..(1 << bits)).flat_map(move |value| {
74                    [F::from_canonical_u32(value), F::from_canonical_usize(bits)].into_iter()
75                })
76            }))
77            .collect();
78        Some(RowMajorMatrix::new(
79            rows,
80            NUM_VARIABLE_RANGE_PREPROCESSED_COLS,
81        ))
82    }
83}
84
85impl<AB: InteractionBuilder + PairBuilder> Air<AB> for VariableRangeCheckerAir {
86    fn eval(&self, builder: &mut AB) {
87        let preprocessed = builder.preprocessed();
88        let prep_local = preprocessed.row_slice(0);
89        let prep_local: &VariableRangePreprocessedCols<AB::Var> = (*prep_local).borrow();
90        let main = builder.main();
91        let local = main.row_slice(0);
92        let local: &VariableRangeCols<AB::Var> = (*local).borrow();
93        // Omit creating separate bridge.rs file for brevity
94        self.bus
95            .receive(prep_local.value, prep_local.max_bits)
96            .eval(builder, local.mult);
97    }
98}
99
100pub struct VariableRangeCheckerChip {
101    pub air: VariableRangeCheckerAir,
102    pub count: Vec<AtomicU32>,
103}
104
105#[derive(Clone)]
106pub struct SharedVariableRangeCheckerChip(Arc<VariableRangeCheckerChip>);
107
108impl VariableRangeCheckerChip {
109    pub fn new(bus: VariableRangeCheckerBus) -> Self {
110        let num_rows = (1 << (bus.range_max_bits + 1)) as usize;
111        let count = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
112        Self {
113            air: VariableRangeCheckerAir::new(bus),
114            count,
115        }
116    }
117
118    pub fn bus(&self) -> VariableRangeCheckerBus {
119        self.air.bus
120    }
121
122    pub fn range_max_bits(&self) -> usize {
123        self.air.range_max_bits()
124    }
125
126    pub fn air_width(&self) -> usize {
127        NUM_VARIABLE_RANGE_COLS
128    }
129
130    #[instrument(
131        name = "VariableRangeCheckerChip::add_count",
132        skip(self),
133        level = "trace"
134    )]
135    pub fn add_count(&self, value: u32, max_bits: usize) {
136        // index is 2^max_bits + value - 1 + 1 for the extra [0, 0] row
137        // if each [value, max_bits] is valid, the sends multiset will be exactly the receives multiset
138        let idx = (1 << max_bits) + (value as usize);
139        assert!(
140            idx < self.count.len(),
141            "range exceeded: {} >= {}",
142            idx,
143            self.count.len()
144        );
145        let val_atomic = &self.count[idx];
146        val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
147    }
148
149    pub fn clear(&self) {
150        for i in 0..self.count.len() {
151            self.count[i].store(0, std::sync::atomic::Ordering::Relaxed);
152        }
153    }
154
155    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
156        let mut rows = F::zero_vec(self.count.len() * NUM_VARIABLE_RANGE_COLS);
157        for (n, row) in rows.chunks_mut(NUM_VARIABLE_RANGE_COLS).enumerate() {
158            let cols: &mut VariableRangeCols<F> = row.borrow_mut();
159            cols.mult =
160                F::from_canonical_u32(self.count[n].load(std::sync::atomic::Ordering::SeqCst));
161        }
162        RowMajorMatrix::new(rows, NUM_VARIABLE_RANGE_COLS)
163    }
164
165    /// Range checks that `value` is `bits` bits by decomposing into `limbs` where all but
166    /// last limb is `range_max_bits` bits. Assumes there are enough limbs.
167    pub fn decompose<F: Field>(&self, mut value: u32, bits: usize, limbs: &mut [F]) {
168        debug_assert!(
169            limbs.len() >= bits.div_ceil(self.range_max_bits()),
170            "Not enough limbs: len {}",
171            limbs.len()
172        );
173        let mask = (1 << self.range_max_bits()) - 1;
174        let mut bits_remaining = bits;
175        for limb in limbs.iter_mut() {
176            let limb_u32 = value & mask;
177            *limb = F::from_canonical_u32(limb_u32);
178            self.add_count(limb_u32, bits_remaining.min(self.range_max_bits()));
179
180            value >>= self.range_max_bits();
181            bits_remaining = bits_remaining.saturating_sub(self.range_max_bits());
182        }
183        debug_assert_eq!(value, 0);
184        debug_assert_eq!(bits_remaining, 0);
185    }
186}
187
188impl SharedVariableRangeCheckerChip {
189    pub fn new(bus: VariableRangeCheckerBus) -> Self {
190        Self(Arc::new(VariableRangeCheckerChip::new(bus)))
191    }
192
193    pub fn bus(&self) -> VariableRangeCheckerBus {
194        self.0.bus()
195    }
196
197    pub fn range_max_bits(&self) -> usize {
198        self.0.range_max_bits()
199    }
200
201    pub fn air_width(&self) -> usize {
202        self.0.air_width()
203    }
204
205    pub fn add_count(&self, value: u32, max_bits: usize) {
206        self.0.add_count(value, max_bits)
207    }
208
209    pub fn clear(&self) {
210        self.0.clear()
211    }
212
213    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
214        self.0.generate_trace()
215    }
216}
217
218impl<SC: StarkGenericConfig> Chip<SC> for VariableRangeCheckerChip
219where
220    Val<SC>: PrimeField32,
221{
222    fn air(&self) -> AirRef<SC> {
223        Arc::new(self.air)
224    }
225
226    fn generate_air_proof_input(self) -> AirProofInput<SC> {
227        let trace = self.generate_trace::<Val<SC>>();
228        AirProofInput::simple_no_pis(trace)
229    }
230}
231
232impl<SC: StarkGenericConfig> Chip<SC> for SharedVariableRangeCheckerChip
233where
234    Val<SC>: PrimeField32,
235{
236    fn air(&self) -> AirRef<SC> {
237        self.0.air()
238    }
239
240    fn generate_air_proof_input(self) -> AirProofInput<SC> {
241        self.0.generate_air_proof_input()
242    }
243}
244
245impl ChipUsageGetter for VariableRangeCheckerChip {
246    fn air_name(&self) -> String {
247        get_air_name(&self.air)
248    }
249    fn constant_trace_height(&self) -> Option<usize> {
250        Some(self.count.len())
251    }
252    fn current_trace_height(&self) -> usize {
253        self.count.len()
254    }
255    fn trace_width(&self) -> usize {
256        NUM_VARIABLE_RANGE_COLS
257    }
258}
259
260impl ChipUsageGetter for SharedVariableRangeCheckerChip {
261    fn air_name(&self) -> String {
262        self.0.air_name()
263    }
264
265    fn constant_trace_height(&self) -> Option<usize> {
266        self.0.constant_trace_height()
267    }
268
269    fn current_trace_height(&self) -> usize {
270        self.0.current_trace_height()
271    }
272
273    fn trace_width(&self) -> usize {
274        self.0.trace_width()
275    }
276}
277
278impl AsRef<VariableRangeCheckerChip> for SharedVariableRangeCheckerChip {
279    fn as_ref(&self) -> &VariableRangeCheckerChip {
280        &self.0
281    }
282}