// openvm_circuit_primitives/range_tuple/mod.rs
1use std::{
7 mem::size_of,
8 sync::{atomic::AtomicU32, Arc},
9};
10
11use openvm_circuit_primitives_derive::AlignedBorrow;
12use openvm_stark_backend::{
13 config::{StarkGenericConfig, Val},
14 interaction::InteractionBuilder,
15 p3_air::{Air, BaseAir, PairBuilder},
16 p3_field::{Field, PrimeField32},
17 p3_matrix::{dense::RowMajorMatrix, Matrix},
18 prover::types::AirProofInput,
19 rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
20 AirRef, Chip, ChipUsageGetter,
21};
22
23mod bus;
24
25#[cfg(test)]
26pub mod tests;
27
28pub use bus::*;
29
30#[repr(C)]
31#[derive(Default, Copy, Clone, AlignedBorrow)]
32pub struct RangeTupleCols<T> {
33 pub mult: T,
35}
36
/// One row of the preprocessed trace: the `N` coordinates of a single tuple.
///
/// A `Vec` is used (rather than a fixed-size array) so the row can be built directly
/// from a matrix row slice of any width.
#[derive(Clone, Default)]
pub struct RangeTuplePreprocessedCols<T> {
    /// Coordinates of the tuple, most significant first.
    pub tuple: Vec<T>,
}
42
43pub const NUM_RANGE_TUPLE_COLS: usize = size_of::<RangeTupleCols<u8>>();
44
45#[derive(Clone, Copy, Debug)]
46pub struct RangeTupleCheckerAir<const N: usize> {
47 pub bus: RangeTupleCheckerBus<N>,
48}
49
50impl<const N: usize> RangeTupleCheckerAir<N> {
51 pub fn height(&self) -> u32 {
52 self.bus.sizes.iter().product()
53 }
54}
55impl<F: Field, const N: usize> BaseAirWithPublicValues<F> for RangeTupleCheckerAir<N> {}
56impl<F: Field, const N: usize> PartitionedBaseAir<F> for RangeTupleCheckerAir<N> {}
57
58impl<F: Field, const N: usize> BaseAir<F> for RangeTupleCheckerAir<N> {
59 fn width(&self) -> usize {
60 NUM_RANGE_TUPLE_COLS
61 }
62
63 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
64 let mut unrolled_matrix = Vec::with_capacity((self.height() as usize) * N);
65 let mut row = [0u32; N];
66 for _ in 0..self.height() {
67 unrolled_matrix.extend(row);
68 for i in (0..N).rev() {
69 if row[i] < self.bus.sizes[i] - 1 {
70 row[i] += 1;
71 break;
72 }
73 row[i] = 0;
74 }
75 }
76 Some(RowMajorMatrix::new(
77 unrolled_matrix
78 .iter()
79 .map(|&v| F::from_canonical_u32(v))
80 .collect(),
81 N,
82 ))
83 }
84}
85
86impl<AB: InteractionBuilder + PairBuilder, const N: usize> Air<AB> for RangeTupleCheckerAir<N> {
87 fn eval(&self, builder: &mut AB) {
88 let preprocessed = builder.preprocessed();
89 let prep_local = preprocessed.row_slice(0);
90 let prep_local = RangeTuplePreprocessedCols {
91 tuple: (*prep_local).to_vec(),
92 };
93 let main = builder.main();
94 let local = main.row_slice(0);
95 let local = RangeTupleCols { mult: (*local)[0] };
96
97 self.bus.receive(prep_local.tuple).eval(builder, local.mult);
98 }
99}
100
101#[derive(Debug)]
102pub struct RangeTupleCheckerChip<const N: usize> {
103 pub air: RangeTupleCheckerAir<N>,
104 pub count: Vec<Arc<AtomicU32>>,
105}
106
107#[derive(Debug, Clone)]
108pub struct SharedRangeTupleCheckerChip<const N: usize>(Arc<RangeTupleCheckerChip<N>>);
109
110impl<const N: usize> RangeTupleCheckerChip<N> {
111 pub fn new(bus: RangeTupleCheckerBus<N>) -> Self {
112 let range_max = bus.sizes.iter().product();
113 let count = (0..range_max)
114 .map(|_| Arc::new(AtomicU32::new(0)))
115 .collect();
116
117 Self {
118 air: RangeTupleCheckerAir { bus },
119 count,
120 }
121 }
122
123 pub fn bus(&self) -> &RangeTupleCheckerBus<N> {
124 &self.air.bus
125 }
126
127 pub fn sizes(&self) -> &[u32; N] {
128 &self.air.bus.sizes
129 }
130
131 pub fn add_count(&self, ids: &[u32]) {
132 let index = ids
133 .iter()
134 .zip(self.air.bus.sizes.iter())
135 .fold(0, |acc, (id, sz)| acc * sz + id) as usize;
136 assert!(
137 index < self.count.len(),
138 "range exceeded: {} >= {}",
139 index,
140 self.count.len()
141 );
142 let val_atomic = &self.count[index];
143 val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
144 }
145
146 pub fn clear(&self) {
147 for val in &self.count {
148 val.store(0, std::sync::atomic::Ordering::Relaxed);
149 }
150 }
151
152 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
153 let rows = self
154 .count
155 .iter()
156 .map(|c| F::from_canonical_u32(c.load(std::sync::atomic::Ordering::SeqCst)))
157 .collect::<Vec<_>>();
158 RowMajorMatrix::new(rows, 1)
159 }
160}
161
162impl<const N: usize> SharedRangeTupleCheckerChip<N> {
163 pub fn new(bus: RangeTupleCheckerBus<N>) -> Self {
164 Self(Arc::new(RangeTupleCheckerChip::new(bus)))
165 }
166 pub fn bus(&self) -> &RangeTupleCheckerBus<N> {
167 self.0.bus()
168 }
169
170 pub fn sizes(&self) -> &[u32; N] {
171 self.0.sizes()
172 }
173
174 pub fn add_count(&self, ids: &[u32]) {
175 self.0.add_count(ids);
176 }
177
178 pub fn clear(&self) {
179 self.0.clear();
180 }
181
182 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
183 self.0.generate_trace()
184 }
185}
186
187impl<SC: StarkGenericConfig, const N: usize> Chip<SC> for RangeTupleCheckerChip<N>
188where
189 Val<SC>: PrimeField32,
190{
191 fn air(&self) -> AirRef<SC> {
192 Arc::new(self.air)
193 }
194
195 fn generate_air_proof_input(self) -> AirProofInput<SC> {
196 let trace = self.generate_trace::<Val<SC>>();
197 AirProofInput::simple_no_pis(trace)
198 }
199}
200
201impl<SC: StarkGenericConfig, const N: usize> Chip<SC> for SharedRangeTupleCheckerChip<N>
202where
203 Val<SC>: PrimeField32,
204{
205 fn air(&self) -> AirRef<SC> {
206 self.0.air()
207 }
208
209 fn generate_air_proof_input(self) -> AirProofInput<SC> {
210 self.0.generate_air_proof_input()
211 }
212}
213
214impl<const N: usize> ChipUsageGetter for RangeTupleCheckerChip<N> {
215 fn air_name(&self) -> String {
216 get_air_name(&self.air)
217 }
218 fn constant_trace_height(&self) -> Option<usize> {
219 Some(self.count.len())
220 }
221 fn current_trace_height(&self) -> usize {
222 self.count.len()
223 }
224 fn trace_width(&self) -> usize {
225 NUM_RANGE_TUPLE_COLS
226 }
227}
228
229impl<const N: usize> ChipUsageGetter for SharedRangeTupleCheckerChip<N> {
230 fn air_name(&self) -> String {
231 self.0.air_name()
232 }
233
234 fn constant_trace_height(&self) -> Option<usize> {
235 self.0.constant_trace_height()
236 }
237
238 fn current_trace_height(&self) -> usize {
239 self.0.current_trace_height()
240 }
241
242 fn trace_width(&self) -> usize {
243 self.0.trace_width()
244 }
245}
246
247impl<const N: usize> AsRef<RangeTupleCheckerChip<N>> for SharedRangeTupleCheckerChip<N> {
248 fn as_ref(&self) -> &RangeTupleCheckerChip<N> {
249 &self.0
250 }
251}