openvm_circuit_primitives/range_gate/
mod.rs

1//! Range check for a fixed bit size without using preprocessed trace.
2//!
3//! Caution: We almost always prefer to use the
4//! [VariableRangeCheckerChip](super::var_range::VariableRangeCheckerChip) instead of this chip.
5
6use std::{
7    borrow::Borrow,
8    mem::{size_of, transmute},
9    sync::atomic::AtomicU32,
10};
11
12use openvm_circuit_primitives_derive::AlignedBorrow;
13use openvm_stark_backend::{
14    interaction::{BusIndex, InteractionBuilder},
15    p3_air::{Air, AirBuilder, BaseAir},
16    p3_field::{Field, PrimeCharacteristicRing},
17    p3_matrix::{dense::RowMajorMatrix, Matrix},
18    p3_util::indices_arr,
19    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
20};
21
22pub use crate::range::RangeCheckBus;
23
24#[cfg(test)]
25mod tests;
26
27#[repr(C)]
28#[derive(Copy, Clone, Default, AlignedBorrow)]
29pub struct RangeGateCols<T> {
30    /// Column with sequential values from 0 to range_max-1
31    pub counter: T,
32    /// Number of range checks requested for each value
33    pub mult: T,
34}
35
36impl<T: Clone> RangeGateCols<T> {
37    pub fn from_slice(slice: &[T]) -> Self {
38        let counter = slice[0].clone();
39        let mult = slice[1].clone();
40
41        Self { counter, mult }
42    }
43}
44
45pub const NUM_RANGE_GATE_COLS: usize = size_of::<RangeGateCols<u8>>();
46pub const RANGE_GATE_COL_MAP: RangeGateCols<usize> = make_col_map();
47
48#[derive(Clone, Copy, Debug, derive_new::new)]
49pub struct RangeCheckerGateAir {
50    pub bus: RangeCheckBus,
51}
52
53impl<F: Field> BaseAirWithPublicValues<F> for RangeCheckerGateAir {}
54impl<F: Field> PartitionedBaseAir<F> for RangeCheckerGateAir {}
55impl<F: Field> BaseAir<F> for RangeCheckerGateAir {
56    fn width(&self) -> usize {
57        NUM_RANGE_GATE_COLS
58    }
59}
60
61impl<AB: InteractionBuilder> Air<AB> for RangeCheckerGateAir {
62    fn eval(&self, builder: &mut AB) {
63        let main = builder.main();
64
65        let (local, next) = (
66            main.row_slice(0).expect("window should have two elements"),
67            main.row_slice(1).expect("window should have two elements"),
68        );
69        let local: &RangeGateCols<AB::Var> = (*local).borrow();
70        let next: &RangeGateCols<AB::Var> = (*next).borrow();
71
72        // Ensure counter starts at 0
73        builder
74            .when_first_row()
75            .assert_eq(local.counter, AB::Expr::ZERO);
76        // Ensure counter increments by 1 in each row
77        builder
78            .when_transition()
79            .assert_eq(local.counter + AB::Expr::ONE, next.counter);
80        // Constrain the last counter value to ensure trace height equals range_max
81        // This is critical as the trace height is not part of the verification key
82        builder
83            .when_last_row()
84            .assert_eq(local.counter, AB::F::from_u32(self.bus.range_max - 1));
85        // Omit creating separate bridge.rs file for brevity
86        self.bus.receive(local.counter).eval(builder, local.mult);
87    }
88}
89
90/// This chip gets requests to verify that a number is in the range
91/// [0, MAX). In the trace, there is a counter column and a multiplicity
92/// column. The counter column is generated using a gate, as opposed to
93/// the other RangeCheckerChip.
94pub struct RangeCheckerGateChip {
95    pub air: RangeCheckerGateAir,
96    pub count: Vec<AtomicU32>,
97}
98
99impl RangeCheckerGateChip {
100    pub fn new(bus: RangeCheckBus) -> Self {
101        let count = (0..bus.range_max).map(|_| AtomicU32::new(0)).collect();
102
103        Self {
104            air: RangeCheckerGateAir::new(bus),
105            count,
106        }
107    }
108
109    pub fn bus(&self) -> RangeCheckBus {
110        self.air.bus
111    }
112
113    pub fn bus_index(&self) -> BusIndex {
114        self.air.bus.inner.index
115    }
116
117    pub fn range_max(&self) -> u32 {
118        self.air.bus.range_max
119    }
120
121    pub fn air_width(&self) -> usize {
122        2
123    }
124
125    pub fn add_count(&self, val: u32) {
126        assert!(
127            val < self.range_max(),
128            "range exceeded: {} >= {}",
129            val,
130            self.range_max()
131        );
132        let val_atomic = &self.count[val as usize];
133        val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
134    }
135
136    pub fn clear(&self) {
137        for i in 0..self.count.len() {
138            self.count[i].store(0, std::sync::atomic::Ordering::Relaxed);
139        }
140    }
141
142    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
143        let rows = self
144            .count
145            .iter()
146            .enumerate()
147            .flat_map(|(i, count)| {
148                let c = count.swap(0, std::sync::atomic::Ordering::Relaxed);
149                vec![F::from_usize(i), F::from_u32(c)]
150            })
151            .collect();
152        RowMajorMatrix::new(rows, NUM_RANGE_GATE_COLS)
153    }
154}
155
156const fn make_col_map() -> RangeGateCols<usize> {
157    let indices_arr = indices_arr::<NUM_RANGE_GATE_COLS>();
158    // SAFETY: RangeGateCols is repr(C) with two fields, same layout as [usize; 2].
159    // NUM_RANGE_GATE_COLS equals 2. Transmute reinterprets array as struct.
160    unsafe { transmute::<[usize; NUM_RANGE_GATE_COLS], RangeGateCols<usize>>(indices_arr) }
161}