// openvm_circuit_primitives/var_range/mod.rs
1use core::mem::size_of;
7use std::{
8 borrow::{Borrow, BorrowMut},
9 sync::{atomic::AtomicU32, Arc},
10};
11
12use openvm_circuit_primitives_derive::AlignedBorrow;
13use openvm_stark_backend::{
14 config::{StarkGenericConfig, Val},
15 interaction::InteractionBuilder,
16 p3_air::{Air, BaseAir, PairBuilder},
17 p3_field::{Field, PrimeField32},
18 p3_matrix::{dense::RowMajorMatrix, Matrix},
19 prover::types::AirProofInput,
20 rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
21 AirRef, Chip, ChipUsageGetter,
22};
23use tracing::instrument;
24
25mod bus;
26#[cfg(test)]
27pub mod tests;
28
29pub use bus::*;
30
31#[derive(Default, AlignedBorrow, Copy, Clone)]
32#[repr(C)]
33pub struct VariableRangeCols<T> {
34 pub mult: T,
36}
37
38#[derive(Default, AlignedBorrow, Copy, Clone)]
39#[repr(C)]
40pub struct VariableRangePreprocessedCols<T> {
41 pub value: T,
43 pub max_bits: T,
45}
46
47pub const NUM_VARIABLE_RANGE_COLS: usize = size_of::<VariableRangeCols<u8>>();
48pub const NUM_VARIABLE_RANGE_PREPROCESSED_COLS: usize =
49 size_of::<VariableRangePreprocessedCols<u8>>();
50
51#[derive(Clone, Copy, Debug, derive_new::new)]
52pub struct VariableRangeCheckerAir {
53 pub bus: VariableRangeCheckerBus,
54}
55
56impl VariableRangeCheckerAir {
57 pub fn range_max_bits(&self) -> usize {
58 self.bus.range_max_bits
59 }
60}
61
62impl<F: Field> BaseAirWithPublicValues<F> for VariableRangeCheckerAir {}
63impl<F: Field> PartitionedBaseAir<F> for VariableRangeCheckerAir {}
64impl<F: Field> BaseAir<F> for VariableRangeCheckerAir {
65 fn width(&self) -> usize {
66 NUM_VARIABLE_RANGE_COLS
67 }
68
69 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
70 let rows: Vec<F> = [F::ZERO; NUM_VARIABLE_RANGE_PREPROCESSED_COLS]
71 .into_iter()
72 .chain((0..=self.range_max_bits()).flat_map(|bits| {
73 (0..(1 << bits)).flat_map(move |value| {
74 [F::from_canonical_u32(value), F::from_canonical_usize(bits)].into_iter()
75 })
76 }))
77 .collect();
78 Some(RowMajorMatrix::new(
79 rows,
80 NUM_VARIABLE_RANGE_PREPROCESSED_COLS,
81 ))
82 }
83}
84
85impl<AB: InteractionBuilder + PairBuilder> Air<AB> for VariableRangeCheckerAir {
86 fn eval(&self, builder: &mut AB) {
87 let preprocessed = builder.preprocessed();
88 let prep_local = preprocessed.row_slice(0);
89 let prep_local: &VariableRangePreprocessedCols<AB::Var> = (*prep_local).borrow();
90 let main = builder.main();
91 let local = main.row_slice(0);
92 let local: &VariableRangeCols<AB::Var> = (*local).borrow();
93 self.bus
95 .receive(prep_local.value, prep_local.max_bits)
96 .eval(builder, local.mult);
97 }
98}
99
100pub struct VariableRangeCheckerChip {
101 pub air: VariableRangeCheckerAir,
102 pub count: Vec<AtomicU32>,
103}
104
105#[derive(Clone)]
106pub struct SharedVariableRangeCheckerChip(Arc<VariableRangeCheckerChip>);
107
108impl VariableRangeCheckerChip {
109 pub fn new(bus: VariableRangeCheckerBus) -> Self {
110 let num_rows = (1 << (bus.range_max_bits + 1)) as usize;
111 let count = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
112 Self {
113 air: VariableRangeCheckerAir::new(bus),
114 count,
115 }
116 }
117
118 pub fn bus(&self) -> VariableRangeCheckerBus {
119 self.air.bus
120 }
121
122 pub fn range_max_bits(&self) -> usize {
123 self.air.range_max_bits()
124 }
125
126 pub fn air_width(&self) -> usize {
127 NUM_VARIABLE_RANGE_COLS
128 }
129
130 #[instrument(
131 name = "VariableRangeCheckerChip::add_count",
132 skip(self),
133 level = "trace"
134 )]
135 pub fn add_count(&self, value: u32, max_bits: usize) {
136 let idx = (1 << max_bits) + (value as usize);
139 assert!(
140 idx < self.count.len(),
141 "range exceeded: {} >= {}",
142 idx,
143 self.count.len()
144 );
145 let val_atomic = &self.count[idx];
146 val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
147 }
148
149 pub fn clear(&self) {
150 for i in 0..self.count.len() {
151 self.count[i].store(0, std::sync::atomic::Ordering::Relaxed);
152 }
153 }
154
155 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
156 let mut rows = F::zero_vec(self.count.len() * NUM_VARIABLE_RANGE_COLS);
157 for (n, row) in rows.chunks_mut(NUM_VARIABLE_RANGE_COLS).enumerate() {
158 let cols: &mut VariableRangeCols<F> = row.borrow_mut();
159 cols.mult =
160 F::from_canonical_u32(self.count[n].load(std::sync::atomic::Ordering::SeqCst));
161 }
162 RowMajorMatrix::new(rows, NUM_VARIABLE_RANGE_COLS)
163 }
164
165 pub fn decompose<F: Field>(&self, mut value: u32, bits: usize, limbs: &mut [F]) {
168 debug_assert!(
169 limbs.len() >= bits.div_ceil(self.range_max_bits()),
170 "Not enough limbs: len {}",
171 limbs.len()
172 );
173 let mask = (1 << self.range_max_bits()) - 1;
174 let mut bits_remaining = bits;
175 for limb in limbs.iter_mut() {
176 let limb_u32 = value & mask;
177 *limb = F::from_canonical_u32(limb_u32);
178 self.add_count(limb_u32, bits_remaining.min(self.range_max_bits()));
179
180 value >>= self.range_max_bits();
181 bits_remaining = bits_remaining.saturating_sub(self.range_max_bits());
182 }
183 debug_assert_eq!(value, 0);
184 debug_assert_eq!(bits_remaining, 0);
185 }
186}
187
188impl SharedVariableRangeCheckerChip {
189 pub fn new(bus: VariableRangeCheckerBus) -> Self {
190 Self(Arc::new(VariableRangeCheckerChip::new(bus)))
191 }
192
193 pub fn bus(&self) -> VariableRangeCheckerBus {
194 self.0.bus()
195 }
196
197 pub fn range_max_bits(&self) -> usize {
198 self.0.range_max_bits()
199 }
200
201 pub fn air_width(&self) -> usize {
202 self.0.air_width()
203 }
204
205 pub fn add_count(&self, value: u32, max_bits: usize) {
206 self.0.add_count(value, max_bits)
207 }
208
209 pub fn clear(&self) {
210 self.0.clear()
211 }
212
213 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
214 self.0.generate_trace()
215 }
216}
217
218impl<SC: StarkGenericConfig> Chip<SC> for VariableRangeCheckerChip
219where
220 Val<SC>: PrimeField32,
221{
222 fn air(&self) -> AirRef<SC> {
223 Arc::new(self.air)
224 }
225
226 fn generate_air_proof_input(self) -> AirProofInput<SC> {
227 let trace = self.generate_trace::<Val<SC>>();
228 AirProofInput::simple_no_pis(trace)
229 }
230}
231
232impl<SC: StarkGenericConfig> Chip<SC> for SharedVariableRangeCheckerChip
233where
234 Val<SC>: PrimeField32,
235{
236 fn air(&self) -> AirRef<SC> {
237 self.0.air()
238 }
239
240 fn generate_air_proof_input(self) -> AirProofInput<SC> {
241 self.0.generate_air_proof_input()
242 }
243}
244
245impl ChipUsageGetter for VariableRangeCheckerChip {
246 fn air_name(&self) -> String {
247 get_air_name(&self.air)
248 }
249 fn constant_trace_height(&self) -> Option<usize> {
250 Some(self.count.len())
251 }
252 fn current_trace_height(&self) -> usize {
253 self.count.len()
254 }
255 fn trace_width(&self) -> usize {
256 NUM_VARIABLE_RANGE_COLS
257 }
258}
259
260impl ChipUsageGetter for SharedVariableRangeCheckerChip {
261 fn air_name(&self) -> String {
262 self.0.air_name()
263 }
264
265 fn constant_trace_height(&self) -> Option<usize> {
266 self.0.constant_trace_height()
267 }
268
269 fn current_trace_height(&self) -> usize {
270 self.0.current_trace_height()
271 }
272
273 fn trace_width(&self) -> usize {
274 self.0.trace_width()
275 }
276}
277
278impl AsRef<VariableRangeCheckerChip> for SharedVariableRangeCheckerChip {
279 fn as_ref(&self) -> &VariableRangeCheckerChip {
280 &self.0
281 }
282}