openvm_circuit_primitives/var_range/
mod.rs1use core::mem::size_of;
7use std::{
8 borrow::{Borrow, BorrowMut},
9 sync::{atomic::AtomicU32, Arc},
10};
11
12use openvm_circuit_primitives_derive::AlignedBorrow;
13use openvm_stark_backend::{
14 config::{StarkGenericConfig, Val},
15 interaction::InteractionBuilder,
16 p3_air::{Air, BaseAir, PairBuilder},
17 p3_field::{Field, PrimeField32},
18 p3_matrix::{dense::RowMajorMatrix, Matrix},
19 prover::{cpu::CpuBackend, types::AirProvingContext},
20 rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
21 Chip, ChipUsageGetter,
22};
23use tracing::instrument;
24
25mod bus;
26pub use bus::*;
27
28#[cfg(feature = "cuda")]
29mod cuda;
30#[cfg(feature = "cuda")]
31pub use cuda::*;
32
33#[cfg(test)]
34pub mod tests;
35
36#[derive(Default, AlignedBorrow, Copy, Clone)]
37#[repr(C)]
38pub struct VariableRangeCols<T> {
39 pub mult: T,
41}
42
43#[derive(Default, AlignedBorrow, Copy, Clone)]
44#[repr(C)]
45pub struct VariableRangePreprocessedCols<T> {
46 pub value: T,
48 pub max_bits: T,
50}
51
52pub const NUM_VARIABLE_RANGE_COLS: usize = size_of::<VariableRangeCols<u8>>();
53pub const NUM_VARIABLE_RANGE_PREPROCESSED_COLS: usize =
54 size_of::<VariableRangePreprocessedCols<u8>>();
55
56#[derive(Clone, Copy, Debug, derive_new::new)]
57pub struct VariableRangeCheckerAir {
58 pub bus: VariableRangeCheckerBus,
59}
60
61impl VariableRangeCheckerAir {
62 pub fn range_max_bits(&self) -> usize {
63 self.bus.range_max_bits
64 }
65}
66
67impl<F: Field> BaseAirWithPublicValues<F> for VariableRangeCheckerAir {}
68impl<F: Field> PartitionedBaseAir<F> for VariableRangeCheckerAir {}
69impl<F: Field> BaseAir<F> for VariableRangeCheckerAir {
70 fn width(&self) -> usize {
71 NUM_VARIABLE_RANGE_COLS
72 }
73
74 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
75 let rows: Vec<F> = [F::ZERO; NUM_VARIABLE_RANGE_PREPROCESSED_COLS]
76 .into_iter()
77 .chain((0..=self.range_max_bits()).flat_map(|bits| {
78 (0..(1 << bits)).flat_map(move |value| {
79 [F::from_canonical_u32(value), F::from_canonical_usize(bits)].into_iter()
80 })
81 }))
82 .collect();
83 Some(RowMajorMatrix::new(
84 rows,
85 NUM_VARIABLE_RANGE_PREPROCESSED_COLS,
86 ))
87 }
88}
89
90impl<AB: InteractionBuilder + PairBuilder> Air<AB> for VariableRangeCheckerAir {
91 fn eval(&self, builder: &mut AB) {
92 let preprocessed = builder.preprocessed();
93 let prep_local = preprocessed.row_slice(0);
94 let prep_local: &VariableRangePreprocessedCols<AB::Var> = (*prep_local).borrow();
95 let main = builder.main();
96 let local = main.row_slice(0);
97 let local: &VariableRangeCols<AB::Var> = (*local).borrow();
98 self.bus
100 .receive(prep_local.value, prep_local.max_bits)
101 .eval(builder, local.mult);
102 }
103}
104
105pub struct VariableRangeCheckerChip {
106 pub air: VariableRangeCheckerAir,
107 pub count: Vec<AtomicU32>,
108}
109
110pub type SharedVariableRangeCheckerChip = Arc<VariableRangeCheckerChip>;
111
112impl VariableRangeCheckerChip {
113 pub fn new(bus: VariableRangeCheckerBus) -> Self {
114 let num_rows = (1 << (bus.range_max_bits + 1)) as usize;
115 let count = (0..num_rows).map(|_| AtomicU32::new(0)).collect();
116 Self {
117 air: VariableRangeCheckerAir::new(bus),
118 count,
119 }
120 }
121
122 pub fn bus(&self) -> VariableRangeCheckerBus {
123 self.air.bus
124 }
125
126 pub fn range_max_bits(&self) -> usize {
127 self.air.range_max_bits()
128 }
129
130 pub fn air_width(&self) -> usize {
131 NUM_VARIABLE_RANGE_COLS
132 }
133
134 #[instrument(
135 name = "VariableRangeCheckerChip::add_count",
136 skip(self),
137 level = "trace"
138 )]
139 pub fn add_count(&self, value: u32, max_bits: usize) {
140 let idx = (1 << max_bits) + (value as usize);
144 assert!(
145 idx < self.count.len(),
146 "range exceeded: {} >= {}",
147 idx,
148 self.count.len()
149 );
150 let val_atomic = &self.count[idx];
151 val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
152 }
153
154 pub fn clear(&self) {
155 for i in 0..self.count.len() {
156 self.count[i].store(0, std::sync::atomic::Ordering::Relaxed);
157 }
158 }
159
160 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
162 let mut rows = F::zero_vec(self.count.len() * NUM_VARIABLE_RANGE_COLS);
163 for (n, row) in rows.chunks_mut(NUM_VARIABLE_RANGE_COLS).enumerate() {
164 let cols: &mut VariableRangeCols<F> = row.borrow_mut();
165 cols.mult =
166 F::from_canonical_u32(self.count[n].swap(0, std::sync::atomic::Ordering::Relaxed));
167 }
168 RowMajorMatrix::new(rows, NUM_VARIABLE_RANGE_COLS)
169 }
170
171 pub fn decompose<F: Field>(&self, mut value: u32, bits: usize, limbs: &mut [F]) {
174 debug_assert!(
175 limbs.len() >= bits.div_ceil(self.range_max_bits()),
176 "Not enough limbs: len {}",
177 limbs.len()
178 );
179 let mask = (1 << self.range_max_bits()) - 1;
180 let mut bits_remaining = bits;
181 for limb in limbs.iter_mut() {
182 let limb_u32 = value & mask;
183 *limb = F::from_canonical_u32(limb_u32);
184 self.add_count(limb_u32, bits_remaining.min(self.range_max_bits()));
185
186 value >>= self.range_max_bits();
187 bits_remaining = bits_remaining.saturating_sub(self.range_max_bits());
188 }
189 debug_assert_eq!(value, 0);
190 debug_assert_eq!(bits_remaining, 0);
191 }
192}
193
194impl<R, SC: StarkGenericConfig> Chip<R, CpuBackend<SC>> for VariableRangeCheckerChip
196where
197 Val<SC>: PrimeField32,
198{
199 fn generate_proving_ctx(&self, _: R) -> AirProvingContext<CpuBackend<SC>> {
201 let trace = self.generate_trace::<Val<SC>>();
202 AirProvingContext::simple_no_pis(Arc::new(trace))
203 }
204}
205
206impl ChipUsageGetter for VariableRangeCheckerChip {
207 fn air_name(&self) -> String {
208 get_air_name(&self.air)
209 }
210 fn constant_trace_height(&self) -> Option<usize> {
211 Some(self.count.len())
212 }
213 fn current_trace_height(&self) -> usize {
214 self.count.len()
215 }
216 fn trace_width(&self) -> usize {
217 NUM_VARIABLE_RANGE_COLS
218 }
219}