// openvm_circuit_primitives/var_range/mod.rs
use core::mem::size_of;
use std::{
    borrow::{Borrow, BorrowMut},
    sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    },
};

use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_stark_backend::{
    config::{StarkGenericConfig, Val},
    interaction::InteractionBuilder,
    p3_air::{Air, BaseAir, PairBuilder},
    p3_field::{Field, PrimeField32},
    p3_matrix::{dense::RowMajorMatrix, Matrix},
    prover::{cpu::CpuBackend, types::AirProvingContext},
    rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
    Chip, ChipUsageGetter,
};
use tracing::instrument;
24
25mod bus;
26pub use bus::*;
27
28#[cfg(feature = "cuda")]
29mod cuda;
30#[cfg(feature = "cuda")]
31pub use cuda::*;
32
33#[cfg(test)]
34pub mod tests;
35
36#[derive(Default, AlignedBorrow, Copy, Clone)]
37#[repr(C)]
38pub struct VariableRangeCols<T> {
39 pub mult: T,
41}
42
43#[derive(Default, AlignedBorrow, Copy, Clone)]
44#[repr(C)]
45pub struct VariableRangePreprocessedCols<T> {
46 pub value: T,
48 pub max_bits: T,
50}
51
52pub const NUM_VARIABLE_RANGE_COLS: usize = size_of::<VariableRangeCols<u8>>();
53pub const NUM_VARIABLE_RANGE_PREPROCESSED_COLS: usize =
54 size_of::<VariableRangePreprocessedCols<u8>>();
55
56#[derive(Clone, Copy, Debug, derive_new::new)]
57pub struct VariableRangeCheckerAir {
58 pub bus: VariableRangeCheckerBus,
59}
60
61impl VariableRangeCheckerAir {
62 pub fn range_max_bits(&self) -> usize {
63 self.bus.range_max_bits
64 }
65}
66
67impl<F: Field> BaseAirWithPublicValues<F> for VariableRangeCheckerAir {}
68impl<F: Field> PartitionedBaseAir<F> for VariableRangeCheckerAir {}
69impl<F: Field> BaseAir<F> for VariableRangeCheckerAir {
70 fn width(&self) -> usize {
71 NUM_VARIABLE_RANGE_COLS
72 }
73
74 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
75 let rows: Vec<F> = [F::ZERO; NUM_VARIABLE_RANGE_PREPROCESSED_COLS]
76 .into_iter()
77 .chain((0..=self.range_max_bits()).flat_map(|bits| {
78 (0..(1 << bits))
79 .flat_map(move |value| [F::from_u32(value), F::from_usize(bits)].into_iter())
80 }))
81 .collect();
82 Some(RowMajorMatrix::new(
83 rows,
84 NUM_VARIABLE_RANGE_PREPROCESSED_COLS,
85 ))
86 }
87}
88
89impl<AB: InteractionBuilder + PairBuilder> Air<AB> for VariableRangeCheckerAir {
90 fn eval(&self, builder: &mut AB) {
91 let preprocessed = builder.preprocessed();
92 let prep_local = preprocessed
93 .row_slice(0)
94 .expect("window should have two elements");
95 let prep_local: &VariableRangePreprocessedCols<AB::Var> = (*prep_local).borrow();
96 let main = builder.main();
97 let local = main.row_slice(0).expect("window should have two elements");
98 let local: &VariableRangeCols<AB::Var> = (*local).borrow();
99 self.bus
101 .receive(prep_local.value, prep_local.max_bits)
102 .eval(builder, local.mult);
103 }
104}
105
106pub struct VariableRangeCheckerChip {
107 pub air: VariableRangeCheckerAir,
108 pub count: Vec<AtomicU32>,
109}
110
111pub type SharedVariableRangeCheckerChip = Arc<VariableRangeCheckerChip>;
112
113impl VariableRangeCheckerChip {
114 pub fn new(bus: VariableRangeCheckerBus) -> Self {
115 let num_rows = (1 << (bus.range_max_bits + 1)) as usize;
116 let count = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
117 Self {
118 air: VariableRangeCheckerAir::new(bus),
119 count,
120 }
121 }
122
123 pub fn bus(&self) -> VariableRangeCheckerBus {
124 self.air.bus
125 }
126
127 pub fn range_max_bits(&self) -> usize {
128 self.air.range_max_bits()
129 }
130
131 pub fn air_width(&self) -> usize {
132 NUM_VARIABLE_RANGE_COLS
133 }
134
135 #[instrument(
136 name = "VariableRangeCheckerChip::add_count",
137 skip(self),
138 level = "trace"
139 )]
140 pub fn add_count(&self, value: u32, max_bits: usize) {
141 let idx = (1 << max_bits) + (value as usize);
145 assert!(
146 idx < self.count.len(),
147 "range exceeded: {} >= {}",
148 idx,
149 self.count.len()
150 );
151 let val_atomic = &self.count[idx];
152 val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
153 }
154
155 pub fn clear(&self) {
156 for i in 0..self.count.len() {
157 self.count[i].store(0, std::sync::atomic::Ordering::Relaxed);
158 }
159 }
160
161 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
163 let mut rows = F::zero_vec(self.count.len() * NUM_VARIABLE_RANGE_COLS);
164 for (n, row) in rows.chunks_mut(NUM_VARIABLE_RANGE_COLS).enumerate() {
165 let cols: &mut VariableRangeCols<F> = row.borrow_mut();
166 cols.mult = F::from_u32(self.count[n].swap(0, std::sync::atomic::Ordering::Relaxed));
167 }
168 RowMajorMatrix::new(rows, NUM_VARIABLE_RANGE_COLS)
169 }
170
171 pub fn decompose<F: Field>(&self, mut value: u32, bits: usize, limbs: &mut [F]) {
174 debug_assert!(
175 limbs.len() >= bits.div_ceil(self.range_max_bits()),
176 "Not enough limbs: len {}",
177 limbs.len()
178 );
179 let mask = (1 << self.range_max_bits()) - 1;
180 let mut bits_remaining = bits;
181 for limb in limbs.iter_mut() {
182 let limb_u32 = value & mask;
183 *limb = F::from_u32(limb_u32);
184 self.add_count(limb_u32, bits_remaining.min(self.range_max_bits()));
185
186 value >>= self.range_max_bits();
187 bits_remaining = bits_remaining.saturating_sub(self.range_max_bits());
188 }
189 debug_assert_eq!(value, 0);
190 debug_assert_eq!(bits_remaining, 0);
191 }
192}
193
194impl<R, SC: StarkGenericConfig> Chip<R, CpuBackend<SC>> for VariableRangeCheckerChip
196where
197 Val<SC>: PrimeField32,
198{
199 fn generate_proving_ctx(&self, _: R) -> AirProvingContext<CpuBackend<SC>> {
201 let trace = self.generate_trace::<Val<SC>>();
202 AirProvingContext::simple_no_pis(Arc::new(trace))
203 }
204}
205
206impl ChipUsageGetter for VariableRangeCheckerChip {
207 fn air_name(&self) -> String {
208 get_air_name(&self.air)
209 }
210 fn constant_trace_height(&self) -> Option<usize> {
211 Some(self.count.len())
212 }
213 fn current_trace_height(&self) -> usize {
214 self.count.len()
215 }
216 fn trace_width(&self) -> usize {
217 NUM_VARIABLE_RANGE_COLS
218 }
219}