openvm_circuit_primitives/range_gate/mod.rs

//! Range check against a fixed bound `range_max` (values in `[0, range_max)`) without using a preprocessed trace.
//!
//! Caution: We almost always prefer to use the [VariableRangeCheckerChip](super::var_range::VariableRangeCheckerChip) instead of this chip.
4
5use std::{
6    borrow::Borrow,
7    mem::{size_of, transmute},
8    sync::atomic::AtomicU32,
9};
10
11use openvm_circuit_primitives_derive::AlignedBorrow;
12use openvm_stark_backend::{
13    interaction::{BusIndex, InteractionBuilder},
14    p3_air::{Air, AirBuilder, BaseAir},
15    p3_field::{Field, FieldAlgebra},
16    p3_matrix::{dense::RowMajorMatrix, Matrix},
17    p3_util::indices_arr,
18    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
19};
20
21pub use crate::range::RangeCheckBus;
22
23#[cfg(test)]
24mod tests;
25
26#[repr(C)]
27#[derive(Copy, Clone, Default, AlignedBorrow)]
28pub struct RangeGateCols<T> {
29    /// Column with sequential values from 0 to range_max-1
30    pub counter: T,
31    /// Number of range checks requested for each value
32    pub mult: T,
33}
34
35impl<T: Clone> RangeGateCols<T> {
36    pub fn from_slice(slice: &[T]) -> Self {
37        let counter = slice[0].clone();
38        let mult = slice[1].clone();
39
40        Self { counter, mult }
41    }
42}
43
44pub const NUM_RANGE_GATE_COLS: usize = size_of::<RangeGateCols<u8>>();
45pub const RANGE_GATE_COL_MAP: RangeGateCols<usize> = make_col_map();
46
47#[derive(Clone, Copy, Debug, derive_new::new)]
48pub struct RangeCheckerGateAir {
49    pub bus: RangeCheckBus,
50}
51
52impl<F: Field> BaseAirWithPublicValues<F> for RangeCheckerGateAir {}
53impl<F: Field> PartitionedBaseAir<F> for RangeCheckerGateAir {}
54impl<F: Field> BaseAir<F> for RangeCheckerGateAir {
55    fn width(&self) -> usize {
56        NUM_RANGE_GATE_COLS
57    }
58}
59
impl<AB: InteractionBuilder> Air<AB> for RangeCheckerGateAir {
    /// Constrains the `counter` column to start at 0, increase by 1 per row, and end at
    /// `range_max - 1`; each row receives `counter` on the bus with multiplicity `mult`.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();

        // View the current row and the following row as typed column structs.
        let (local, next) = (main.row_slice(0), main.row_slice(1));
        let local: &RangeGateCols<AB::Var> = (*local).borrow();
        let next: &RangeGateCols<AB::Var> = (*next).borrow();

        // Ensure counter starts at 0
        builder
            .when_first_row()
            .assert_eq(local.counter, AB::Expr::ZERO);
        // Ensure counter increments by 1 in each row
        builder
            .when_transition()
            .assert_eq(local.counter + AB::Expr::ONE, next.counter);
        // Constrain the last counter value to ensure trace height equals range_max
        // This is critical as the trace height is not part of the verification key
        // NOTE(review): `range_max - 1` is u32 arithmetic; it underflows (panics in debug,
        // wraps in release) if range_max == 0.
        builder.when_last_row().assert_eq(
            local.counter,
            AB::F::from_canonical_u32(self.bus.range_max - 1),
        );
        // Omit creating separate bridge.rs file for brevity
        // Receive `counter` on the range-check bus with multiplicity `mult`.
        self.bus.receive(local.counter).eval(builder, local.mult);
    }
}
86
87/// This chip gets requests to verify that a number is in the range
88/// [0, MAX). In the trace, there is a counter column and a multiplicity
89/// column. The counter column is generated using a gate, as opposed to
90/// the other RangeCheckerChip.
91pub struct RangeCheckerGateChip {
92    pub air: RangeCheckerGateAir,
93    pub count: Vec<AtomicU32>,
94}
95
96impl RangeCheckerGateChip {
97    pub fn new(bus: RangeCheckBus) -> Self {
98        let count = (0..bus.range_max).map(|_| AtomicU32::new(0)).collect();
99
100        Self {
101            air: RangeCheckerGateAir::new(bus),
102            count,
103        }
104    }
105
106    pub fn bus(&self) -> RangeCheckBus {
107        self.air.bus
108    }
109
110    pub fn bus_index(&self) -> BusIndex {
111        self.air.bus.inner.index
112    }
113
114    pub fn range_max(&self) -> u32 {
115        self.air.bus.range_max
116    }
117
118    pub fn air_width(&self) -> usize {
119        2
120    }
121
122    pub fn add_count(&self, val: u32) {
123        assert!(
124            val < self.range_max(),
125            "range exceeded: {} >= {}",
126            val,
127            self.range_max()
128        );
129        let val_atomic = &self.count[val as usize];
130        val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
131    }
132
133    pub fn clear(&self) {
134        for i in 0..self.count.len() {
135            self.count[i].store(0, std::sync::atomic::Ordering::Relaxed);
136        }
137    }
138
139    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
140        let rows = self
141            .count
142            .iter()
143            .enumerate()
144            .flat_map(|(i, count)| {
145                let c = count.load(std::sync::atomic::Ordering::Relaxed);
146                vec![F::from_canonical_usize(i), F::from_canonical_u32(c)]
147            })
148            .collect();
149        RowMajorMatrix::new(rows, NUM_RANGE_GATE_COLS)
150    }
151}
152
/// Builds the column map for [`RangeGateCols`]: each field holds its own column index.
const fn make_col_map() -> RangeGateCols<usize> {
    let indices_arr = indices_arr::<NUM_RANGE_GATE_COLS>();
    // SAFETY: `RangeGateCols<T>` is `#[repr(C)]` and every field has type `T`.
    // `NUM_RANGE_GATE_COLS` is its size with `T = u8`, i.e. its field count, so
    // `RangeGateCols<usize>` has the same layout as `[usize; NUM_RANGE_GATE_COLS]`.
    // `transmute` also rejects any size mismatch at compile time.
    unsafe { transmute::<[usize; NUM_RANGE_GATE_COLS], RangeGateCols<usize>>(indices_arr) }
}