openvm_circuit_primitives/range_gate/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
//! Range check for a fixed bit size without using preprocessed trace.
//!
//! Caution: We almost always prefer to use the [VariableRangeCheckerChip](super::var_range::VariableRangeCheckerChip) instead of this chip.

use std::{
    borrow::Borrow,
    mem::{size_of, transmute},
    sync::atomic::AtomicU32,
};

use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{Air, AirBuilder, BaseAir},
    p3_field::{AbstractField, Field},
    p3_matrix::{dense::RowMajorMatrix, Matrix},
    p3_util::indices_arr,
    rap::{BaseAirWithPublicValues, PartitionedBaseAir},
};

pub use crate::range::RangeCheckBus;

#[cfg(test)]
mod tests;

/// One row of the range-gate trace: a counter value and the multiplicity with
/// which that value is received on the range-check bus.
///
/// Rows are reinterpreted as this struct from flat column data (`borrow()` on a
/// row slice in `RangeCheckerGateAir::eval`, and a `transmute` from
/// `[usize; NUM_RANGE_GATE_COLS]` in `make_col_map`), and `generate_trace`
/// writes each row positionally as `[counter, mult]`. `#[repr(C)]` guarantees
/// the fields are laid out in declaration order so those views agree; the
/// default Rust representation makes no such promise.
#[repr(C)]
#[derive(Copy, Clone, Default, AlignedBorrow)]
pub struct RangeGateCols<T> {
    /// Column 0: the value being range checked (0, 1, 2, ... down the trace).
    pub counter: T,
    /// Column 1: multiplicity with which `counter` is received on the bus.
    pub mult: T,
}

impl<T: Clone> RangeGateCols<T> {
    /// Builds a row by cloning the first two entries of `slice`: `slice[0]`
    /// becomes `counter` and `slice[1]` becomes `mult`. Any further entries
    /// are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `slice` contains fewer than two elements.
    pub fn from_slice(slice: &[T]) -> Self {
        Self {
            counter: slice[0].clone(),
            mult: slice[1].clone(),
        }
    }
}

/// Number of columns in a range-gate trace row.
///
/// Computed as the byte size of `RangeGateCols<u8>`: every field is a single
/// `u8`, so the size in bytes equals the number of fields.
pub const NUM_RANGE_GATE_COLS: usize = size_of::<RangeGateCols<u8>>();
/// Column layout built by `make_col_map`, which reinterprets the output of
/// `indices_arr` as a `RangeGateCols<usize>`.
// NOTE(review): presumably each field ends up holding its own column index
// (`counter == 0`, `mult == 1`) — confirm against `p3_util::indices_arr`.
pub const RANGE_GATE_COL_MAP: RangeGateCols<usize> = make_col_map();

/// AIR of the gate-based range checker. Its constraints (see the `Air` impl)
/// are parameterised only by `bus`.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct RangeCheckerGateAir {
    /// Bus on which counter values are received. Its `range_max` also fixes
    /// the counter value required on the last trace row (`range_max - 1`).
    pub bus: RangeCheckBus,
}

// Empty impls: no items of these two traits are overridden for this AIR.
impl<F: Field> BaseAirWithPublicValues<F> for RangeCheckerGateAir {}
impl<F: Field> PartitionedBaseAir<F> for RangeCheckerGateAir {}
impl<F: Field> BaseAir<F> for RangeCheckerGateAir {
    /// Trace width: one column per field of `RangeGateCols`.
    fn width(&self) -> usize {
        NUM_RANGE_GATE_COLS
    }
}

impl<AB: InteractionBuilder> Air<AB> for RangeCheckerGateAir {
    /// Adds the range-gate constraints to `builder`:
    ///
    /// 1. first row: `counter == 0`;
    /// 2. every transition: `next.counter == local.counter + 1`;
    /// 3. last row: `counter == bus.range_max - 1`;
    /// 4. bus interaction: `counter` is received on `self.bus` with
    ///    multiplicity `mult`.
    ///
    /// Together 1-3 force the counter column to be exactly
    /// `0, 1, ..., range_max - 1`.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();

        // Rows 0 and 1 of the main view, used below as the current row and its
        // successor, each viewed as `RangeGateCols`.
        let (local, next) = (main.row_slice(0), main.row_slice(1));
        let local: &RangeGateCols<AB::Var> = (*local).borrow();
        let next: &RangeGateCols<AB::Var> = (*next).borrow();

        // The counter starts at zero...
        builder
            .when_first_row()
            .assert_eq(local.counter, AB::Expr::ZERO);
        // ...and increases by exactly one from each row to the next.
        builder
            .when_transition()
            .assert_eq(local.counter + AB::Expr::ONE, next.counter);
        // The trace height is not part of the vkey, so we must enforce it here.
        // NOTE(review): `range_max - 1` is plain `u32` arithmetic and underflows
        // if `range_max == 0`; presumably callers never build such a bus.
        builder.when_last_row().assert_eq(
            local.counter,
            AB::F::from_canonical_u32(self.bus.range_max - 1),
        );
        // Omit creating separate bridge.rs file for brevity
        // Receive `counter` on the range-check bus, `mult` times.
        self.bus.receive(local.counter).eval(builder, local.mult);
    }
}

/// This chip gets requests to verify that a number is in the range
/// `[0, range_max)`, where `range_max` comes from the chip's bus.
///
/// The trace has a counter column and a multiplicity column. The counter
/// column is constrained by a gate in the AIR (first-row, transition and
/// last-row constraints) rather than being supplied as a preprocessed trace.
#[derive(Debug)]
pub struct RangeCheckerGateChip {
    /// AIR (and, through it, the bus) this chip generates traces for.
    pub air: RangeCheckerGateAir,
    /// `count[v]` is the number of `add_count(v)` calls recorded since
    /// construction or the last `clear()`; it has one entry per value in
    /// `[0, range_max)`. Atomics allow updates through `&self`.
    pub count: Vec<AtomicU32>,
}

impl RangeCheckerGateChip {
    /// Creates a chip for `bus` with one zero-initialised counter for every
    /// value in `[0, bus.range_max)`.
    pub fn new(bus: RangeCheckBus) -> Self {
        let count = (0..bus.range_max).map(|_| AtomicU32::new(0)).collect();

        Self {
            air: RangeCheckerGateAir::new(bus),
            count,
        }
    }

    /// Returns the range-check bus used by this chip's AIR.
    pub fn bus(&self) -> RangeCheckBus {
        self.air.bus
    }

    /// Returns the `index` field of the chip's bus.
    pub fn bus_index(&self) -> usize {
        self.air.bus.index
    }

    /// Returns the exclusive upper bound of the checked range.
    pub fn range_max(&self) -> u32 {
        self.air.bus.range_max
    }

    /// Returns the number of trace columns.
    ///
    /// Uses `NUM_RANGE_GATE_COLS` (the same constant `BaseAir::width` returns)
    /// so the two can never drift apart if `RangeGateCols` changes.
    pub fn air_width(&self) -> usize {
        NUM_RANGE_GATE_COLS
    }

    /// Records one range-check request for `val`.
    ///
    /// # Panics
    ///
    /// Panics if `val >= self.range_max()`.
    pub fn add_count(&self, val: u32) {
        assert!(
            val < self.range_max(),
            "range exceeded: {} >= {}",
            val,
            self.range_max()
        );
        let val_atomic = &self.count[val as usize];
        val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    /// Resets every recorded multiplicity to zero.
    pub fn clear(&self) {
        for counter in &self.count {
            counter.store(0, std::sync::atomic::Ordering::Relaxed);
        }
    }

    /// Builds the trace from the recorded multiplicities.
    ///
    /// Row `i` is `[i, count[i]]` (counter column first, then multiplicity),
    /// so the matrix has `count.len()` rows and `NUM_RANGE_GATE_COLS` columns.
    pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
        // One allocation for the whole matrix instead of a temporary `Vec` per row.
        let mut values = Vec::with_capacity(self.count.len() * NUM_RANGE_GATE_COLS);
        for (i, count) in self.count.iter().enumerate() {
            let mult = count.load(std::sync::atomic::Ordering::Relaxed);
            values.push(F::from_canonical_usize(i));
            values.push(F::from_canonical_u32(mult));
        }
        RowMajorMatrix::new(values, NUM_RANGE_GATE_COLS)
    }
}

/// Builds the column map by reinterpreting the array returned by
/// `indices_arr::<NUM_RANGE_GATE_COLS>()` as a `RangeGateCols<usize>`.
const fn make_col_map() -> RangeGateCols<usize> {
    let column_indices: [usize; NUM_RANGE_GATE_COLS] = indices_arr::<NUM_RANGE_GATE_COLS>();
    // SAFETY: `transmute` checks at compile time that both types have the same
    // size, and every bit pattern is a valid `usize`. The result is only
    // meaningful if `RangeGateCols<usize>` lays its fields out in declaration
    // order, like the array.
    // NOTE(review): confirm the struct's representation guarantees that order.
    unsafe { transmute::<[usize; NUM_RANGE_GATE_COLS], RangeGateCols<usize>>(column_indices) }
}