openvm_circuit_primitives/range_gate/
mod.rs1use std::{
7 borrow::Borrow,
8 mem::{size_of, transmute},
9 sync::atomic::AtomicU32,
10};
11
12use openvm_circuit_primitives_derive::AlignedBorrow;
13use openvm_stark_backend::{
14 interaction::{BusIndex, InteractionBuilder},
15 p3_air::{Air, AirBuilder, BaseAir},
16 p3_field::{Field, FieldAlgebra},
17 p3_matrix::{dense::RowMajorMatrix, Matrix},
18 p3_util::indices_arr,
19 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
20};
21
22pub use crate::range::RangeCheckBus;
23
24#[cfg(test)]
25mod tests;
26
/// Column layout of one row of the range-checker gate trace.
///
/// The struct is `#[repr(C)]`, so field declaration order is the column order
/// within a trace row (`counter` is column 0, `mult` is column 1). Do not
/// reorder fields without updating every place that writes rows positionally.
#[repr(C)]
#[derive(Copy, Clone, Default, AlignedBorrow)]
pub struct RangeGateCols<T> {
    /// The value this row represents. The AIR constrains it to start at 0 and
    /// increase by one per row, ending at `range_max - 1`.
    pub counter: T,
    /// How many times `counter` is received on the range-check bus, i.e. how
    /// many range-check requests for this value were recorded by the chip.
    pub mult: T,
}
35
36impl<T: Clone> RangeGateCols<T> {
37 pub fn from_slice(slice: &[T]) -> Self {
38 let counter = slice[0].clone();
39 let mult = slice[1].clone();
40
41 Self { counter, mult }
42 }
43}
44
/// Number of columns in a range-gate trace row.
///
/// Computed as the byte size of `RangeGateCols<u8>`: with `#[repr(C)]` and
/// one-byte fields, each field contributes exactly one to the count.
pub const NUM_RANGE_GATE_COLS: usize = size_of::<RangeGateCols<u8>>();
/// Column index of each `RangeGateCols` field within a trace row.
pub const RANGE_GATE_COL_MAP: RangeGateCols<usize> = make_col_map();
47
/// AIR for the range-checker gate.
///
/// The trace enumerates every value in `0..bus.range_max` once. Each value is
/// received on `bus` with its recorded multiplicity.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct RangeCheckerGateAir {
    /// Range-check bus this AIR receives on; also supplies `range_max`.
    pub bus: RangeCheckBus,
}
52
53impl<F: Field> BaseAirWithPublicValues<F> for RangeCheckerGateAir {}
54impl<F: Field> PartitionedBaseAir<F> for RangeCheckerGateAir {}
55impl<F: Field> BaseAir<F> for RangeCheckerGateAir {
56 fn width(&self) -> usize {
57 NUM_RANGE_GATE_COLS
58 }
59}
60
/// Constraints for the range-checker gate.
///
/// Together, the first-row, transition and last-row constraints force the
/// `counter` column to read 0, 1, 2, ... down the trace and to end at
/// `range_max - 1`. Each row then receives its `counter` on the bus with
/// multiplicity `mult`.
impl<AB: InteractionBuilder> Air<AB> for RangeCheckerGateAir {
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();

        // Current row and the row after it, viewed through the column struct.
        let (local, next) = (main.row_slice(0), main.row_slice(1));
        let local: &RangeGateCols<AB::Var> = (*local).borrow();
        let next: &RangeGateCols<AB::Var> = (*next).borrow();

        // The enumeration starts at zero.
        builder
            .when_first_row()
            .assert_eq(local.counter, AB::Expr::ZERO);
        // Each transition (every row but the last) increments the counter by one.
        builder
            .when_transition()
            .assert_eq(local.counter + AB::Expr::ONE, next.counter);
        // The enumeration must end exactly at the top of the range.
        // NOTE: `range_max - 1` is a u32 subtraction and overflows if
        // `range_max == 0`.
        builder.when_last_row().assert_eq(
            local.counter,
            AB::F::from_canonical_u32(self.bus.range_max - 1),
        );
        // Receive `counter` on the range-check bus, weighted by `mult`.
        self.bus.receive(local.counter).eval(builder, local.mult);
    }
}
87
/// Trace generator for [`RangeCheckerGateAir`].
///
/// Accumulates, per value in `0..range_max`, how many range checks of that
/// value were requested. The counts later become the `mult` column of the trace.
pub struct RangeCheckerGateChip {
    /// The AIR whose trace this chip produces.
    pub air: RangeCheckerGateAir,
    /// One counter per value in `0..range_max`; `count[v]` is the number of
    /// requests recorded for `v` since the last `clear` / `generate_trace`.
    pub count: Vec<AtomicU32>,
}
96
97impl RangeCheckerGateChip {
98 pub fn new(bus: RangeCheckBus) -> Self {
99 let count = (0..bus.range_max).map(|_| AtomicU32::new(0)).collect();
100
101 Self {
102 air: RangeCheckerGateAir::new(bus),
103 count,
104 }
105 }
106
107 pub fn bus(&self) -> RangeCheckBus {
108 self.air.bus
109 }
110
111 pub fn bus_index(&self) -> BusIndex {
112 self.air.bus.inner.index
113 }
114
115 pub fn range_max(&self) -> u32 {
116 self.air.bus.range_max
117 }
118
119 pub fn air_width(&self) -> usize {
120 2
121 }
122
123 pub fn add_count(&self, val: u32) {
124 assert!(
125 val < self.range_max(),
126 "range exceeded: {} >= {}",
127 val,
128 self.range_max()
129 );
130 let val_atomic = &self.count[val as usize];
131 val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
132 }
133
134 pub fn clear(&self) {
135 for i in 0..self.count.len() {
136 self.count[i].store(0, std::sync::atomic::Ordering::Relaxed);
137 }
138 }
139
140 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
141 let rows = self
142 .count
143 .iter()
144 .enumerate()
145 .flat_map(|(i, count)| {
146 let c = count.swap(0, std::sync::atomic::Ordering::Relaxed);
147 vec![F::from_canonical_usize(i), F::from_canonical_u32(c)]
148 })
149 .collect();
150 RowMajorMatrix::new(rows, NUM_RANGE_GATE_COLS)
151 }
152}
153
154const fn make_col_map() -> RangeGateCols<usize> {
155 let indices_arr = indices_arr::<NUM_RANGE_GATE_COLS>();
156 unsafe { transmute::<[usize; NUM_RANGE_GATE_COLS], RangeGateCols<usize>>(indices_arr) }
159}