// openvm_circuit_primitives/range_tuple/mod.rs
use std::{
8 mem::size_of,
9 sync::{atomic::AtomicU32, Arc},
10};
11
12use openvm_circuit_primitives_derive::AlignedBorrow;
13use openvm_stark_backend::{
14 config::{StarkGenericConfig, Val},
15 interaction::InteractionBuilder,
16 p3_air::{Air, BaseAir, PairBuilder},
17 p3_field::{Field, PrimeField32},
18 p3_matrix::{dense::RowMajorMatrix, Matrix},
19 prover::{cpu::CpuBackend, types::AirProvingContext},
20 rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir},
21 Chip, ChipUsageGetter,
22};
23
24mod bus;
25pub use bus::*;
26
27#[cfg(feature = "cuda")]
28mod cuda;
29#[cfg(feature = "cuda")]
30pub use cuda::*;
31
32#[cfg(test)]
33pub mod tests;
34
/// Main trace columns for the range tuple checker.
///
/// The main trace is a single column: the multiplicity with which the
/// tuple on the corresponding preprocessed row has been requested.
#[repr(C)]
#[derive(Default, Copy, Clone, AlignedBorrow)]
pub struct RangeTupleCols<T> {
    // Number of range-check requests recorded for this row's tuple.
    pub mult: T,
}
41
/// Preprocessed trace columns for the range tuple checker: the tuple of
/// values enumerated on each row (length `N`, one entry per dimension).
#[derive(Default, Clone)]
pub struct RangeTuplePreprocessedCols<T> {
    pub tuple: Vec<T>,
}
47
/// Width of the main trace: a single multiplicity column.
pub const NUM_RANGE_TUPLE_COLS: usize = size_of::<RangeTupleCols<u8>>();
49
/// AIR that range-checks tuples `(x_0, ..., x_{N-1})` with
/// `x_i < bus.sizes[i]`. Its preprocessed trace enumerates every possible
/// tuple; the main trace holds the multiplicity of each one.
#[derive(Clone, Copy, Debug)]
pub struct RangeTupleCheckerAir<const N: usize> {
    pub bus: RangeTupleCheckerBus<N>,
}
54
55impl<const N: usize> RangeTupleCheckerAir<N> {
56 pub fn height(&self) -> u32 {
57 self.bus.sizes.iter().product()
58 }
59}
// Marker trait impls using the traits' default behavior: no public values,
// no trace partitioning.
impl<F: Field, const N: usize> BaseAirWithPublicValues<F> for RangeTupleCheckerAir<N> {}
impl<F: Field, const N: usize> PartitionedBaseAir<F> for RangeTupleCheckerAir<N> {}
62
63impl<F: Field, const N: usize> BaseAir<F> for RangeTupleCheckerAir<N> {
64 fn width(&self) -> usize {
65 NUM_RANGE_TUPLE_COLS
66 }
67
68 fn preprocessed_trace(&self) -> Option<RowMajorMatrix<F>> {
69 let mut unrolled_matrix = Vec::with_capacity((self.height() as usize) * N);
70 let mut row = [0u32; N];
71 for _ in 0..self.height() {
72 unrolled_matrix.extend(row);
73 for i in (0..N).rev() {
74 if row[i] < self.bus.sizes[i] - 1 {
75 row[i] += 1;
76 break;
77 }
78 row[i] = 0;
79 }
80 }
81 Some(RowMajorMatrix::new(
82 unrolled_matrix.iter().map(|&v| F::from_u32(v)).collect(),
83 N,
84 ))
85 }
86}
87
88impl<AB: InteractionBuilder + PairBuilder, const N: usize> Air<AB> for RangeTupleCheckerAir<N> {
89 fn eval(&self, builder: &mut AB) {
90 let preprocessed = builder.preprocessed();
91 let prep_local = preprocessed.row_slice(0).unwrap();
92 let prep_local = RangeTuplePreprocessedCols {
93 tuple: (*prep_local).to_vec(),
94 };
95 let main = builder.main();
96 let local = main.row_slice(0).expect("window should have two elements");
97 let local = RangeTupleCols { mult: (*local)[0] };
98
99 self.bus.receive(prep_local.tuple).eval(builder, local.mult);
100 }
101}
102
/// Stateful chip that accumulates range-check multiplicities for a
/// `RangeTupleCheckerAir`.
#[derive(Debug)]
pub struct RangeTupleCheckerChip<const N: usize> {
    pub air: RangeTupleCheckerAir<N>,
    // One atomic counter per possible tuple, indexed by the flattened
    // (row-major) tuple index; atomics allow `&self` updates from
    // concurrent callers.
    pub count: Vec<Arc<AtomicU32>>,
}
108
/// Shared, reference-counted handle to a [`RangeTupleCheckerChip`].
pub type SharedRangeTupleCheckerChip<const N: usize> = Arc<RangeTupleCheckerChip<N>>;
110
111impl<const N: usize> RangeTupleCheckerChip<N> {
112 pub fn new(bus: RangeTupleCheckerBus<N>) -> Self {
113 let range_max = bus.sizes.iter().product();
114 let count = (0..range_max)
115 .map(|_| Arc::new(AtomicU32::new(0)))
116 .collect();
117
118 Self {
119 air: RangeTupleCheckerAir { bus },
120 count,
121 }
122 }
123
124 pub fn bus(&self) -> &RangeTupleCheckerBus<N> {
125 &self.air.bus
126 }
127
128 pub fn sizes(&self) -> &[u32; N] {
129 &self.air.bus.sizes
130 }
131
132 pub fn add_count(&self, ids: &[u32]) {
133 let index = ids
134 .iter()
135 .zip(self.air.bus.sizes.iter())
136 .fold(0, |acc, (id, sz)| acc * sz + id) as usize;
137 assert!(
138 index < self.count.len(),
139 "range exceeded: {} >= {}",
140 index,
141 self.count.len()
142 );
143 let val_atomic = &self.count[index];
144 val_atomic.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
145 }
146
147 pub fn clear(&self) {
148 for val in &self.count {
149 val.store(0, std::sync::atomic::Ordering::Relaxed);
150 }
151 }
152
153 pub fn generate_trace<F: Field>(&self) -> RowMajorMatrix<F> {
154 let rows = self
155 .count
156 .iter()
157 .map(|c| F::from_u32(c.swap(0, std::sync::atomic::Ordering::Relaxed)))
158 .collect::<Vec<_>>();
159 RowMajorMatrix::new(rows, 1)
160 }
161}
162
impl<R, SC: StarkGenericConfig, const N: usize> Chip<R, CpuBackend<SC>> for RangeTupleCheckerChip<N>
where
    Val<SC>: PrimeField32,
{
    // Builds the proving context from the accumulated multiplicities. The
    // record argument is unused: all state lives in the chip itself. Note
    // `generate_trace` drains (zeroes) the counters as a side effect.
    fn generate_proving_ctx(&self, _: R) -> AirProvingContext<CpuBackend<SC>> {
        let trace = self.generate_trace::<Val<SC>>();
        AirProvingContext::simple_no_pis(Arc::new(trace))
    }
}
172
impl<const N: usize> ChipUsageGetter for RangeTupleCheckerChip<N> {
    fn air_name(&self) -> String {
        get_air_name(&self.air)
    }
    // The height is fixed at construction: one row per possible tuple.
    fn constant_trace_height(&self) -> Option<usize> {
        Some(self.count.len())
    }
    fn current_trace_height(&self) -> usize {
        self.count.len()
    }
    fn trace_width(&self) -> usize {
        NUM_RANGE_TUPLE_COLS
    }
}